From 5095fbc12bf57b4fd795a2ca3e2f723a644b0447 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Thu, 24 Oct 2024 10:39:48 -0400 Subject: [PATCH] Introduce checkstyle plugin - Based on https://github.com/spring-io/spring-javaformat - In this iteration, checkstyles are only enabled for spring-ai-core --- .devcontainer/scripts/onCreateCommand.sh | 16 + .editorconfig | 2 + .mvn/extensions.xml | 16 + .mvn/wrapper/maven-wrapper.properties | 27 +- document-readers/markdown-reader/pom.xml | 16 + .../markdown/MarkdownDocumentReader.java | 89 +- .../config/MarkdownDocumentReaderConfig.java | 30 +- .../markdown/MarkdownDocumentReaderTest.java | 23 +- document-readers/pdf-reader/pom.xml | 16 + .../ai/reader/pdf/PagePdfDocumentReader.java | 27 +- .../pdf/ParagraphPdfDocumentReader.java | 23 +- .../reader/pdf/aot/PdfReaderRuntimeHints.java | 11 +- .../reader/pdf/config/ParagraphManager.java | 63 +- .../pdf/config/PdfDocumentReaderConfig.java | 21 +- .../pdf/layout/ForkPDFLayoutTextStripper.java | 46 +- .../layout/PDFLayoutTextStripperByArea.java | 33 +- .../pdf/PagePdfDocumentReaderTests.java | 5 +- .../pdf/ParagraphPdfDocumentReaderTests.java | 5 +- .../pdf/aot/PdfReaderRuntimeHintsTests.java | 6 +- document-readers/tika-reader/pom.xml | 16 + .../ai/reader/tika/TikaDocumentReader.java | 5 +- .../reader/tika/TikaDocumentReaderTests.java | 5 +- models/spring-ai-anthropic/pom.xml | 16 + .../ai/anthropic/AnthropicChatModel.java | 35 +- .../ai/anthropic/AnthropicChatOptions.java | 193 +- .../anthropic/aot/AnthropicRuntimeHints.java | 5 +- .../ai/anthropic/api/AnthropicApi.java | 401 ++-- .../ai/anthropic/api/StreamHelper.java | 25 +- .../metadata/AnthropicRateLimit.java | 5 +- .../ai/anthropic/metadata/AnthropicUsage.java | 13 +- .../ai/anthropic/AnthropicChatModelIT.java | 60 +- .../AnthropicChatModelObservationIT.java | 23 +- .../anthropic/AnthropicTestConfiguration.java | 5 +- .../anthropic/ChatCompletionRequestTests.java | 5 +- .../ai/anthropic/EventParsingTests.java | 38 +- .../aot/AnthropicRuntimeHintsTests.java | 9 +- .../ai/anthropic/api/AnthropicApiIT.java | 15 +- .../api/tool/AnthropicApiLegacyToolIT.java | 25 +- .../api/tool/AnthropicApiToolIT.java | 25 +- .../api/tool/MockWeatherService.java | 57 +- .../ai/anthropic/api/tool/XmlHelper.java | 83 +- .../client/AnthropicChatClientIT.java | 50 +- .../application-logging-test.properties | 16 + models/spring-ai-azure-openai/pom.xml | 16 + .../AzureOpenAiAudioTranscriptionModel.java | 36 +- .../AzureOpenAiAudioTranscriptionOptions.java | 228 +-- .../ai/azure/openai/AzureOpenAiChatModel.java | 89 +- .../azure/openai/AzureOpenAiChatOptions.java | 300 +-- .../openai/AzureOpenAiEmbeddingModel.java | 18 +- .../openai/AzureOpenAiEmbeddingOptions.java | 116 +- .../azure/openai/AzureOpenAiImageModel.java | 35 +- .../azure/openai/AzureOpenAiImageOptions.java | 143 +- .../openai/AzureOpenAiResponseFormat.java | 5 +- .../ai/azure/openai/MergeUtils.java | 23 +- .../openai/aot/AzureOpenAiRuntimeHints.java | 5 +- ...nAiAudioTranscriptionResponseMetadata.java | 16 +- .../metadata/AzureOpenAiEmbeddingUsage.java | 16 +- .../AzureOpenAiImageGenerationMetadata.java | 34 +- .../AzureOpenAiImageResponseMetadata.java | 43 +- .../openai/metadata/AzureOpenAiUsage.java | 19 +- .../AzureChatCompletionsOptionsTests.java | 22 +- .../openai/AzureEmbeddingsOptionsTests.java | 5 +- .../AzureOpenAiAudioTranscriptionModelIT.java | 27 +- .../azure/openai/AzureOpenAiChatClientIT.java | 38 +- .../azure/openai/AzureOpenAiChatModelIT.java | 44 +- .../AzureOpenAiChatModelObservationIT.java | 34 +- .../openai/AzureOpenAiEmbeddingModelIT.java | 20 +- ...zureOpenAiEmbeddingModelObservationIT.java | 24 +- .../azure/openai/MockAiTestConfiguration.java | 33 +- .../MockAzureOpenAiTestConfiguration.java | 13 +- .../aot/AzureOpenAiRuntimeHintsTests.java | 5 +- .../AzureOpenAiChatModelFunctionCallIT.java | 48 +- .../openai/function/MockWeatherService.java | 59 +- .../openai/image/AzureOpenAiImageModelIT.java | 19 +- .../AzureOpenAiChatModelMetadataTests.java | 9 +- models/spring-ai-bedrock/pom.xml | 16 + .../ai/bedrock/BedrockUsage.java | 13 +- .../ai/bedrock/MessageToPromptConverter.java | 13 +- .../anthropic/AnthropicChatOptions.java | 100 +- .../anthropic/BedrockAnthropicChatModel.java | 15 +- .../api/AnthropicChatBedrockApi.java | 115 +- .../anthropic3/Anthropic3ChatOptions.java | 102 +- .../BedrockAnthropic3ChatModel.java | 21 +- .../api/Anthropic3ChatBedrockApi.java | 177 +- .../ai/bedrock/aot/BedrockRuntimeHints.java | 5 +- .../ai/bedrock/api/AbstractBedrockApi.java | 40 +- .../cohere/BedrockCohereChatModel.java | 11 +- .../cohere/BedrockCohereChatOptions.java | 133 +- .../cohere/BedrockCohereEmbeddingModel.java | 5 +- .../cohere/BedrockCohereEmbeddingOptions.java | 45 +- .../cohere/api/CohereChatBedrockApi.java | 162 +- .../cohere/api/CohereEmbeddingBedrockApi.java | 70 +- .../BedrockAi21Jurassic2ChatModel.java | 24 +- .../BedrockAi21Jurassic2ChatOptions.java | 102 +- .../api/Ai21Jurassic2ChatBedrockApi.java | 114 +- .../bedrock/llama/BedrockLlamaChatModel.java | 11 +- .../llama/BedrockLlamaChatOptions.java | 65 +- .../llama/api/LlamaChatBedrockApi.java | 195 +- .../bedrock/titan/BedrockTitanChatModel.java | 11 +- .../titan/BedrockTitanChatOptions.java | 84 +- .../titan/BedrockTitanEmbeddingModel.java | 21 +- .../titan/BedrockTitanEmbeddingOptions.java | 39 +- .../titan/api/TitanChatBedrockApi.java | 149 +- .../titan/api/TitanEmbeddingBedrockApi.java | 79 +- .../BedrockAnthropicChatModelIT.java | 31 +- .../BedrockAnthropicCreateRequestTests.java | 7 +- .../api/AnthropicChatBedrockApiIT.java | 17 +- .../BedrockAnthropic3ChatModelIT.java | 36 +- .../BedrockAnthropic3CreateRequestTests.java | 16 +- .../api/Anthropic3ChatBedrockApiIT.java | 36 +- .../bedrock/aot/BedrockRuntimeHintsTests.java | 14 +- .../BedrockCohereChatCreateRequestTests.java | 7 +- .../cohere/BedrockCohereChatModelIT.java | 30 +- .../cohere/BedrockCohereEmbeddingModelIT.java | 23 +- .../cohere/api/CohereChatBedrockApiIT.java | 17 +- .../api/CohereEmbeddingBedrockApiIT.java | 11 +- .../BedrockAi21Jurassic2ChatModelIT.java | 24 +- .../api/Ai21Jurassic2ChatBedrockApiIT.java | 8 +- .../llama/BedrockLlamaChatModelIT.java | 30 +- .../llama/BedrockLlamaCreateRequestTests.java | 12 +- .../llama/api/LlamaChatBedrockApiIT.java | 20 +- ...drockTitanChatModelCreateRequestTests.java | 7 +- .../titan/BedrockTitanChatModelIT.java | 4 +- .../titan/BedrockTitanEmbeddingModelIT.java | 19 +- .../titan/api/TitanChatBedrockApiIT.java | 9 +- .../titan/api/TitanEmbeddingBedrockApiIT.java | 9 +- models/spring-ai-huggingface/pom.xml | 16 + .../ai/huggingface/HuggingfaceChatModel.java | 18 +- .../HuggingfaceTestConfiguration.java | 5 +- .../ai/huggingface/client/ClientIT.java | 12 +- models/spring-ai-minimax/pom.xml | 16 + .../ai/minimax/MiniMaxChatModel.java | 100 +- .../ai/minimax/MiniMaxChatOptions.java | 415 +++-- .../ai/minimax/MiniMaxEmbeddingModel.java | 12 +- .../ai/minimax/MiniMaxEmbeddingOptions.java | 36 +- .../ai/minimax/aot/MiniMaxRuntimeHints.java | 5 +- .../ai/minimax/api/MiniMaxApi.java | 420 ++--- .../ai/minimax/api/MiniMaxApiConstants.java | 16 + .../MiniMaxStreamFunctionCallingHelper.java | 11 +- .../ai/minimax/metadata/MiniMaxUsage.java | 13 +- .../minimax/ChatCompletionRequestTests.java | 10 +- .../ai/minimax/MiniMaxTestConfiguration.java | 5 +- .../ai/minimax/api/MiniMaxApiIT.java | 21 +- .../api/MiniMaxApiToolFunctionCallIT.java | 43 +- .../ai/minimax/api/MiniMaxRetryTests.java | 104 +- .../ai/minimax/api/MockWeatherService.java | 65 +- .../chat/MiniMaxChatModelObservationIT.java | 22 +- .../minimax/chat/MiniMaxChatOptionsTests.java | 37 +- .../ai/minimax/embedding/EmbeddingIT.java | 22 +- .../MiniMaxEmbeddingModelObservationIT.java | 14 +- models/spring-ai-mistral-ai/pom.xml | 103 +- .../ai/mistralai/MistralAiChatModel.java | 51 +- .../ai/mistralai/MistralAiChatOptions.java | 257 +-- .../ai/mistralai/MistralAiEmbeddingModel.java | 13 +- .../mistralai/MistralAiEmbeddingOptions.java | 6 +- .../mistralai/aot/MistralAiRuntimeHints.java | 5 +- .../ai/mistralai/api/MistralAiApi.java | 434 ++--- .../MistralAiStreamFunctionCallingHelper.java | 9 +- .../ai/mistralai/metadata/MistralAiUsage.java | 24 +- .../ai/mistralai/MistralAiChatClientIT.java | 47 +- .../MistralAiChatCompletionRequestTest.java | 9 +- .../ai/mistralai/MistralAiChatModelIT.java | 38 +- .../MistralAiChatModelObservationIT.java | 22 +- .../ai/mistralai/MistralAiEmbeddingIT.java | 17 +- .../MistralAiEmbeddingModelObservationIT.java | 14 +- .../ai/mistralai/MistralAiRetryTests.java | 109 +- .../mistralai/MistralAiTestConfiguration.java | 5 +- .../ai/mistralai/MockWeatherService.java | 59 +- .../aot/MistralAiRuntimeHintsTests.java | 9 +- .../ai/mistralai/api/MistralAiApiIT.java | 13 +- .../tool/MistralAiApiToolFunctionCallIT.java | 41 +- .../api/tool/MockWeatherService.java | 59 +- .../tool/PaymentStatusFunctionCallingIT.java | 80 +- models/spring-ai-moonshot/pom.xml | 16 + .../ai/moonshot/MoonshotChatModel.java | 52 +- .../ai/moonshot/MoonshotChatOptions.java | 326 ++-- .../ai/moonshot/aot/MoonshotRuntimeHints.java | 5 +- .../ai/moonshot/api/MoonshotApi.java | 325 ++-- .../ai/moonshot/api/MoonshotConstants.java | 5 +- .../MoonshotStreamFunctionCallingHelper.java | 22 +- .../ai/moonshot/metadata/MoonshotUsage.java | 24 +- .../MoonshotChatCompletionRequestTest.java | 10 +- .../ai/moonshot/MoonshotRetryTests.java | 82 +- .../moonshot/MoonshotTestConfiguration.java | 5 +- .../aot/MoonshotRuntimeHintsTests.java | 10 +- .../ai/moonshot/api/MockWeatherService.java | 65 +- .../ai/moonshot/api/MoonshotApiIT.java | 18 +- .../api/MoonshotApiToolFunctionCallIT.java | 53 +- .../ai/moonshot/chat/ActorsFilms.java | 11 +- .../MoonshotChatModelFunctionCallingIT.java | 24 +- .../ai/moonshot/chat/MoonshotChatModelIT.java | 35 +- .../chat/MoonshotChatModelObservationIT.java | 22 +- models/spring-ai-oci-genai/pom.xml | 16 + .../ai/oci/OCIEmbeddingModel.java | 10 +- .../ai/oci/OCIEmbeddingOptions.java | 80 +- .../ai/oci/BaseEmbeddingModelTest.java | 5 +- .../ai/oci/OCIEmbeddingModelIT.java | 18 +- models/spring-ai-ollama/pom.xml | 16 + .../ai/ollama/OllamaChatModel.java | 59 +- .../ai/ollama/OllamaEmbeddingModel.java | 24 +- .../ai/ollama/aot/OllamaRuntimeHints.java | 5 +- .../ai/ollama/api/OllamaApi.java | 489 ++--- .../ai/ollama/api/OllamaModel.java | 5 +- .../ai/ollama/api/OllamaOptions.java | 195 +- .../management/ModelManagementOptions.java | 8 +- .../ollama/management/OllamaModelManager.java | 58 +- .../ollama/management/PullModelStrategy.java | 5 +- .../ai/ollama/management/package-info.java | 4 +- .../ai/ollama/metadata/OllamaChatUsage.java | 20 +- .../ollama/metadata/OllamaEmbeddingUsage.java | 15 +- .../ai/ollama/BaseOllamaIT.java | 3 +- .../OllamaChatModelFunctionCallingIT.java | 25 +- .../ai/ollama/OllamaChatModelIT.java | 47 +- .../ollama/OllamaChatModelMultimodalIT.java | 19 +- .../ollama/OllamaChatModelObservationIT.java | 22 +- .../ai/ollama/OllamaChatRequestTests.java | 20 +- .../ai/ollama/OllamaEmbeddingModelIT.java | 30 +- .../OllamaEmbeddingModelObservationIT.java | 15 +- .../ai/ollama/OllamaEmbeddingModelTests.java | 62 +- .../ollama/OllamaEmbeddingRequestTests.java | 14 +- .../ai/ollama/OllamaImage.java | 5 +- .../ollama/aot/OllamaRuntimeHintsTests.java | 10 +- .../ai/ollama/api/OllamaApiIT.java | 20 +- .../ai/ollama/api/OllamaApiModelsIT.java | 14 +- .../ollama/api/OllamaModelOptionsTests.java | 9 +- .../ollama/api/tool/MockWeatherService.java | 61 +- .../api/tool/OllamaApiToolFunctionCallIT.java | 26 +- .../management/OllamaModelManagerIT.java | 16 +- models/spring-ai-openai/pom.xml | 16 + .../ai/openai/ImageResponseMetadata.java | 16 + .../ai/openai/OpenAiAudioSpeechModel.java | 21 +- .../ai/openai/OpenAiAudioSpeechOptions.java | 186 +- .../openai/OpenAiAudioTranscriptionModel.java | 29 +- .../OpenAiAudioTranscriptionOptions.java | 113 +- .../ai/openai/OpenAiChatModel.java | 19 +- .../ai/openai/OpenAiChatOptions.java | 399 ++-- .../ai/openai/OpenAiEmbeddingModel.java | 12 +- .../ai/openai/OpenAiEmbeddingOptions.java | 73 +- .../ai/openai/OpenAiImageModel.java | 12 +- .../ai/openai/OpenAiImageOptions.java | 152 +- .../ai/openai/OpenAiModerationModel.java | 25 +- .../ai/openai/OpenAiModerationOptions.java | 26 +- .../ai/openai/aot/OpenAiRuntimeHints.java | 9 +- .../ai/openai/api/OpenAiApi.java | 558 +++--- .../ai/openai/api/OpenAiAudioApi.java | 406 ++-- .../ai/openai/api/OpenAiImageApi.java | 40 +- .../ai/openai/api/OpenAiModerationApi.java | 49 +- .../OpenAiStreamFunctionCallingHelper.java | 9 +- .../common/OpenAiApiClientErrorException.java | 5 +- .../openai/api/common/OpenAiApiConstants.java | 16 + .../ai/openai/audio/speech/Speech.java | 27 +- .../ai/openai/audio/speech/SpeechMessage.java | 19 +- .../ai/openai/audio/speech/SpeechModel.java | 4 +- .../ai/openai/audio/speech/SpeechPrompt.java | 27 +- .../openai/audio/speech/SpeechResponse.java | 30 +- .../audio/speech/StreamingSpeechModel.java | 7 +- .../OpenAiImageGenerationMetadata.java | 23 +- .../OpenAiModerationGenerationMetadata.java | 4 +- .../ai/openai/metadata/OpenAiRateLimit.java | 5 +- .../ai/openai/metadata/OpenAiUsage.java | 13 +- .../audio/OpenAiAudioSpeechMetadata.java | 5 +- .../OpenAiAudioSpeechResponseMetadata.java | 34 +- ...nAiAudioTranscriptionResponseMetadata.java | 29 +- .../support/OpenAiApiResponseHeaders.java | 5 +- .../OpenAiResponseHeaderExtractor.java | 5 +- .../ai/openai/ChatCompletionRequestTests.java | 5 +- .../ai/openai/OpenAiImageOptionsTests.java | 5 +- .../ai/openai/OpenAiTestConfiguration.java | 7 +- .../ai/openai/TranscriptionRequestTests.java | 7 +- .../ai/openai/acme/AcmeIT.java | 29 +- .../openai/aot/OpenAiRuntimeHintsTests.java | 10 +- .../ai/openai/api/OpenAiApiIT.java | 11 +- .../openai/api/tool/MockWeatherService.java | 61 +- .../api/tool/OpenAiApiToolFunctionCallIT.java | 35 +- .../ai/openai/audio/api/OpenAiAudioApiIT.java | 17 +- .../audio/speech/OpenAiSpeechModelIT.java | 23 +- ...hModelWithSpeechResponseMetadataTests.java | 25 +- .../OpenAiTranscriptionModelIT.java | 15 +- ...ithTranscriptionResponseMetadataTests.java | 22 +- .../TranscriptionModelTests.java | 5 +- .../ai/openai/chat/ActorsFilms.java | 11 +- .../openai/chat/MessageTypeContentTests.java | 50 +- ...OpenAiChatModeAdditionalHttpHeadersIT.java | 12 +- .../OpenAiChatModelFunctionCallingIT.java | 26 +- .../ai/openai/chat/OpenAiChatModelIT.java | 60 +- .../chat/OpenAiChatModelObservationIT.java | 23 +- .../chat/OpenAiChatModelProxyToolCallsIT.java | 108 +- .../chat/OpenAiChatModelResponseFormatIT.java | 54 +- ...delTypeReferenceBeanOutputConverterIT.java | 22 +- ...hatModelWithChatResponseMetadataTests.java | 15 +- .../chat/OpenAiCompatibleChatModelIT.java | 24 +- .../chat/OpenAiPaymentTransactionIT.java | 115 +- .../ai/openai/chat/OpenAiRetryTests.java | 133 +- .../chat/client/OpenAiChatClientIT.java | 48 +- ...enAiChatClientMultipleFunctionCallsIT.java | 53 +- .../openai/chat/client/ReReadingAdvisor.java | 9 +- .../chat/proxy/GroqWithOpenAiChatModelIT.java | 48 +- .../proxy/MistralWithOpenAiChatModelIT.java | 50 +- .../proxy/NvidiaWithOpenAiChatModelIT.java | 42 +- .../proxy/OllamaWithOpenAiChatModelIT.java | 66 +- .../ai/openai/embedding/EmbeddingIT.java | 28 +- .../OpenAiEmbeddingModelObservationIT.java | 14 +- .../ai/openai/image/OpenAiImageModelIT.java | 7 +- .../image/OpenAiImageModelObservationIT.java | 10 +- ...geModelWithImageResponseMetadataTests.java | 18 +- .../ai/openai/metadata/OpenAiUsageTests.java | 17 + .../OpenAiResponseHeaderExtractorTests.java | 9 +- .../moderation/OpenAiModerationModelIT.java | 16 +- .../OpenAiModerationModelTests.java | 23 +- .../ai/openai/testutils/AbstractIT.java | 21 +- .../transformer/MetadataTransformerIT.java | 27 +- .../SimplePersistentVectorStoreIT.java | 28 +- .../application-logging-test.properties | 16 + models/spring-ai-postgresml/pom.xml | 16 + .../postgresml/PostgresMlEmbeddingModel.java | 61 +- .../PostgresMlEmbeddingOptions.java | 85 +- .../PostgresMlEmbeddingModelIT.java | 19 +- .../PostgresMlEmbeddingOptionsTests.java | 5 +- models/spring-ai-qianfan/pom.xml | 16 + .../ai/qianfan/QianFanChatModel.java | 26 +- .../ai/qianfan/QianFanChatOptions.java | 217 ++- .../ai/qianfan/QianFanEmbeddingModel.java | 10 +- .../ai/qianfan/QianFanEmbeddingOptions.java | 52 +- .../ai/qianfan/QianFanImageModel.java | 10 +- .../ai/qianfan/QianFanImageOptions.java | 118 +- .../ai/qianfan/aot/QianFanRuntimeHints.java | 5 +- .../ai/qianfan/api/QianFanApi.java | 220 +-- .../ai/qianfan/api/QianFanConstants.java | 5 +- .../ai/qianfan/api/QianFanImageApi.java | 39 +- .../ai/qianfan/api/QianFanUtils.java | 20 +- .../qianfan/api/auth/AccessTokenResponse.java | 17 + .../ai/qianfan/api/auth/AuthApi.java | 16 + .../qianfan/api/auth/QianFanAccessToken.java | 32 +- .../api/auth/QianFanAuthenticator.java | 28 +- .../ai/qianfan/metadata/QianFanUsage.java | 13 +- .../qianfan/ChatCompletionRequestTests.java | 6 +- .../ai/qianfan/QianFanTestConfiguration.java | 5 +- .../ai/qianfan/api/QianFanApiIT.java | 23 +- .../ai/qianfan/api/QianFanRetryTests.java | 114 +- .../ai/qianfan/chat/QianFanChatModelIT.java | 26 +- .../chat/QianFanChatModelObservationIT.java | 22 +- .../ai/qianfan/embedding/EmbeddingIT.java | 23 +- .../QianFanEmbeddingModelObservationIT.java | 14 +- .../ai/qianfan/image/QianFanImageModelIT.java | 8 +- .../image/QianFanImageModelObservationIT.java | 10 +- models/spring-ai-stability-ai/pom.xml | 16 + .../StabilityAiImageGenerationMetadata.java | 15 +- .../ai/stabilityai/StabilityAiImageModel.java | 47 +- .../ai/stabilityai/StyleEnum.java | 7 +- .../ai/stabilityai/api/StabilityAiApi.java | 41 +- .../api/StabilityAiImageOptions.java | 222 +-- .../ai/stabilityai/StabilityAiApiIT.java | 46 +- .../stabilityai/StabilityAiImageModelIT.java | 37 +- .../StabilityAiImageTestConfiguration.java | 5 +- models/spring-ai-transformers/pom.xml | 16 + .../ai/transformers/ResourceCacheService.java | 5 +- .../TransformersEmbeddingModel.java | 83 +- .../ResourceCacheServiceTests.java | 35 +- ...formersEmbeddingModelObservationTests.java | 17 +- .../TransformersEmbeddingModelTests.java | 5 +- .../ai/transformers/samples/ONNXSample.java | 7 +- models/spring-ai-vertex-ai-embedding/pom.xml | 16 + .../VertexAiEmbeddingConnectionDetails.java | 57 +- .../embedding/VertexAiEmbeddingUsage.java | 16 + .../embedding/VertexAiEmbeddingUtils.java | 104 +- .../VertexAiMultimodalEmbeddingModel.java | 43 +- .../VertexAiMultimodalEmbeddingModelName.java | 5 +- .../VertexAiMultimodalEmbeddingOptions.java | 95 +- .../text/VertexAiTextEmbeddingModel.java | 34 +- .../text/VertexAiTextEmbeddingModelName.java | 5 +- .../text/VertexAiTextEmbeddingOptions.java | 199 +- .../VertexAiMultimodalEmbeddingModelIT.java | 38 +- .../text/TestVertexAiTextEmbeddingModel.java | 19 +- .../text/VertexAiTextEmbeddingModelIT.java | 16 +- ...rtexAiTextEmbeddingModelObservationIT.java | 19 +- .../text/VertexAiTextEmbeddingRetryTests.java | 85 +- models/spring-ai-vertex-ai-gemini/pom.xml | 16 + .../ai/vertexai/gemini/MimeTypeDetector.java | 35 +- .../gemini/VertexAiGeminiChatModel.java | 382 ++-- .../gemini/VertexAiGeminiChatOptions.java | 285 +-- .../aot/VertexAiGeminiRuntimeHints.java | 5 +- .../common/VertexAiGeminiConstants.java | 4 +- .../gemini/metadata/VertexAiUsage.java | 11 +- .../gemini/CreateGeminiRequestTests.java | 29 +- .../gemini/TestVertexAiGeminiChatModel.java | 15 +- .../VertexAiChatModelObservationIT.java | 25 +- .../gemini/VertexAiGeminiChatModelIT.java | 39 +- .../gemini/VertexAiGeminiRetryTests.java | 81 +- .../aot/VertexAiGeminiRuntimeHintsTests.java | 5 +- .../gemini/function/MockWeatherService.java | 59 +- ...texAiGeminiChatModelFunctionCallingIT.java | 42 +- .../VertexAiGeminiPaymentTransactionIT.java | 106 +- models/spring-ai-vertex-ai-palm2/pom.xml | 16 + .../palm2/VertexAiPaLm2ChatModel.java | 9 +- .../palm2/VertexAiPaLm2ChatOptions.java | 79 +- .../palm2/VertexAiPaLm2EmbeddingModel.java | 5 +- .../palm2/aot/VertexRuntimeHints.java | 5 +- .../vertexai/palm2/api/VertexAiPaLm2Api.java | 17 +- .../VertexAiPaLm2ChatGenerationClientIT.java | 24 +- .../palm2/VertexAiPaLm2ChatRequestTests.java | 11 +- .../palm2/VertexAiPaLm2EmbeddingModelIT.java | 17 +- .../palm2/aot/VertexRuntimeHintsTests.java | 9 +- .../palm2/api/VertexAiPaLm2ApiIT.java | 19 +- .../palm2/api/VertexAiPaLm2ApiTests.java | 40 +- models/spring-ai-watsonx-ai/pom.xml | 16 + .../ai/watsonx/WatsonxAiChatModel.java | 13 +- .../ai/watsonx/WatsonxAiChatOptions.java | 152 +- .../ai/watsonx/WatsonxAiEmbeddingModel.java | 30 +- .../ai/watsonx/WatsonxAiEmbeddingOptions.java | 43 +- .../ai/watsonx/aot/WatsonxAiRuntimeHints.java | 5 +- .../ai/watsonx/api/WatsonxAiApi.java | 17 +- .../ai/watsonx/api/WatsonxAiChatRequest.java | 18 +- .../ai/watsonx/api/WatsonxAiChatResponse.java | 11 +- .../ai/watsonx/api/WatsonxAiChatResults.java | 5 +- .../api/WatsonxAiEmbeddingRequest.java | 51 +- .../api/WatsonxAiEmbeddingResponse.java | 23 +- .../api/WatsonxAiEmbeddingResults.java | 19 +- .../utils/MessageToPromptConverter.java | 22 +- .../ai/watsonx/WatsonxAiChatModelTest.java | 17 +- .../watsonx/WatsonxAiEmbeddingModelTest.java | 44 +- .../aot/WatsonxAiRuntimeHintsTest.java | 9 +- .../watsonx/api/WatsonxAiChatOptionTest.java | 11 +- .../api/WatsonxAiEmbeddingOptionTest.java | 21 +- .../utils/MessageToPromptConverterTest.java | 35 +- models/spring-ai-zhipuai/pom.xml | 16 + .../ai/zhipuai/ZhiPuAiChatModel.java | 64 +- .../ai/zhipuai/ZhiPuAiChatOptions.java | 355 ++-- .../ai/zhipuai/ZhiPuAiEmbeddingModel.java | 14 +- .../ai/zhipuai/ZhiPuAiEmbeddingOptions.java | 36 +- .../ai/zhipuai/ZhiPuAiImageModel.java | 14 +- .../ai/zhipuai/ZhiPuAiImageOptions.java | 70 +- .../ai/zhipuai/aot/ZhiPuAiRuntimeHints.java | 5 +- .../ai/zhipuai/api/ZhiPuAiApi.java | 400 ++-- .../ai/zhipuai/api/ZhiPuAiImageApi.java | 39 +- .../ZhiPuAiStreamFunctionCallingHelper.java | 11 +- .../ai/zhipuai/api/ZhiPuApiConstants.java | 16 + .../ai/zhipuai/metadata/ZhiPuAiUsage.java | 13 +- .../zhipuai/ChatCompletionRequestTests.java | 10 +- .../ai/zhipuai/ZhiPuAiTestConfiguration.java | 5 +- .../ai/zhipuai/api/MockWeatherService.java | 65 +- .../ai/zhipuai/api/ZhiPuAiApiIT.java | 22 +- .../api/ZhiPuAiApiToolFunctionCallIT.java | 41 +- .../ai/zhipuai/api/ZhiPuAiRetryTests.java | 119 +- .../ai/zhipuai/chat/ActorsFilms.java | 11 +- .../ai/zhipuai/chat/ZhiPuAiChatModelIT.java | 65 +- .../chat/ZhiPuAiChatModelObservationIT.java | 22 +- .../ai/zhipuai/embedding/EmbeddingIT.java | 22 +- .../ZhiPuAiEmbeddingModelObservationIT.java | 14 +- .../ai/zhipuai/image/ZhiPuAiImageModelIT.java | 8 +- mvnw | 29 +- pom.xml | 65 + settings.xml | 16 + spring-ai-bom/pom.xml | 16 + spring-ai-core/pom.xml | 24 +- .../org/springframework/ai/ResourceUtils.java | 5 +- .../ai/aot/AiRuntimeHints.java | 25 +- .../ai/aot/KnuddelsRuntimeHints.java | 7 +- .../ai/aot/SpringAiCoreRuntimeHints.java | 20 +- .../transcription/AudioTranscription.java | 24 +- .../AudioTranscriptionMetadata.java | 6 +- .../AudioTranscriptionOptions.java | 5 +- .../AudioTranscriptionPrompt.java | 9 +- .../AudioTranscriptionResponse.java | 15 +- .../AudioTranscriptionResponseMetadata.java | 5 +- .../ai/chat/client/ChatClient.java | 9 +- .../ai/chat/client/ChatClientCustomizer.java | 2 +- .../ai/chat/client/DefaultChatClient.java | 218 +-- .../chat/client/DefaultChatClientBuilder.java | 6 +- .../chat/client/RequestResponseAdvisor.java | 10 +- .../ai/chat/client/ResponseEntity.java | 3 +- .../advisor/AbstractChatMemoryAdvisor.java | 13 +- .../advisor/DefaultAroundAdvisorChain.java | 53 +- .../LastMaxTokenSizeContentPurger.java | 8 +- .../advisor/MessageChatMemoryAdvisor.java | 16 +- .../advisor/PromptChatMemoryAdvisor.java | 16 +- .../client/advisor/QuestionAnswerAdvisor.java | 30 +- .../chat/client/advisor/SafeGuardAdvisor.java | 21 +- .../client/advisor/SimpleLoggerAdvisor.java | 17 +- .../advisor/VectorStoreChatMemoryAdvisor.java | 14 +- .../client/advisor/api/AdvisedRequest.java | 112 +- .../client/advisor/api/AdvisedResponse.java | 42 +- .../ai/chat/client/advisor/api/Advisor.java | 29 +- .../client/advisor/api/CallAroundAdvisor.java | 31 +- .../advisor/api/CallAroundAdvisorChain.java | 31 +- .../advisor/api/StreamAroundAdvisor.java | 31 +- .../advisor/api/StreamAroundAdvisorChain.java | 31 +- .../AdvisorObservationContext.java | 65 +- .../AdvisorObservationConvention.java | 31 +- .../AdvisorObservationDocumentation.java | 31 +- .../DefaultAdvisorObservationConvention.java | 35 +- .../advisor/observation/package-info.java | 4 +- ...atClientInputContentObservationFilter.java | 13 +- .../ChatClientObservationContext.java | 43 +- .../ChatClientObservationConvention.java | 31 +- .../ChatClientObservationDocumentation.java | 31 +- ...efaultChatClientObservationConvention.java | 37 +- .../chat/client/observation/package-info.java | 4 +- .../ai/chat/memory/ChatMemory.java | 4 +- .../ai/chat/memory/InMemoryChatMemory.java | 4 +- .../ai/chat/messages/AbstractMessage.java | 25 +- .../ai/chat/messages/AssistantMessage.java | 29 +- .../ai/chat/messages/Message.java | 5 +- .../ai/chat/messages/MessageType.java | 13 +- .../ai/chat/messages/SystemMessage.java | 22 +- .../ai/chat/messages/ToolResponseMessage.java | 29 +- .../ai/chat/messages/UserMessage.java | 9 +- .../chat/metadata/ChatGenerationMetadata.java | 5 +- .../chat/metadata/ChatResponseMetadata.java | 61 +- .../ai/chat/metadata/DefaultUsage.java | 33 +- .../ai/chat/metadata/EmptyRateLimit.java | 5 +- .../ai/chat/metadata/EmptyUsage.java | 5 +- .../ai/chat/metadata/PromptMetadata.java | 5 +- .../ai/chat/metadata/RateLimit.java | 5 +- .../ai/chat/metadata/Usage.java | 5 +- .../chat/model/AbstractToolCallSupport.java | 6 +- .../ai/chat/model/ChatModel.java | 10 +- .../ai/chat/model/ChatResponse.java | 33 +- .../ai/chat/model/Generation.java | 11 +- .../ai/chat/model/MessageAggregator.java | 9 +- .../ai/chat/model/StreamingChatModel.java | 5 +- .../ai/chat/model/ToolContext.java | 29 +- .../ChatModelCompletionObservationFilter.java | 6 +- ...ChatModelCompletionObservationHandler.java | 6 +- .../ChatModelMeterObservationHandler.java | 9 +- .../ChatModelObservationContentProcessor.java | 12 +- .../ChatModelObservationContext.java | 17 +- .../ChatModelObservationConvention.java | 5 +- .../ChatModelObservationDocumentation.java | 6 +- ...atModelPromptContentObservationFilter.java | 6 +- ...tModelPromptContentObservationHandler.java | 6 +- ...DefaultChatModelObservationConvention.java | 16 +- .../ai/chat/observation/package-info.java | 4 +- .../springframework/ai/chat/package-info.java | 8 +- .../chat/prompt/AssistantPromptTemplate.java | 11 +- .../ai/chat/prompt/ChatOptions.java | 9 +- .../ai/chat/prompt/ChatOptionsBuilder.java | 129 +- .../ai/chat/prompt/ChatPromptTemplate.java | 17 +- .../chat/prompt/FunctionPromptTemplate.java | 5 +- .../ai/chat/prompt/Prompt.java | 11 +- .../ai/chat/prompt/PromptTemplate.java | 17 +- .../ai/chat/prompt/PromptTemplateActions.java | 5 +- .../prompt/PromptTemplateChatActions.java | 9 +- .../prompt/PromptTemplateMessageActions.java | 11 +- .../prompt/PromptTemplateStringActions.java | 5 +- .../ai/chat/prompt/SystemPromptTemplate.java | 9 +- .../ai/chat/prompt/TemplateFormat.java | 13 +- ...tractConversionServiceOutputConverter.java | 5 +- .../AbstractMessageOutputConverter.java | 5 +- .../ai/converter/BeanOutputConverter.java | 68 +- .../ai/converter/FormatProvider.java | 5 +- .../ai/converter/ListOutputConverter.java | 5 +- .../ai/converter/MapOutputConverter.java | 5 +- .../springframework/ai/converter/README.md | 6 +- .../converter/StructuredOutputConverter.java | 5 +- .../ai/document/ContentFormatter.java | 5 +- .../ai/document/DefaultContentFormatter.java | 149 +- .../springframework/ai/document/Document.java | 229 +-- .../ai/document/DocumentReader.java | 5 +- .../ai/document/DocumentRetriever.java | 5 +- .../ai/document/DocumentTransformer.java | 5 +- .../ai/document/DocumentWriter.java | 5 +- .../ai/document/MetadataMode.java | 9 +- .../ai/document/id/IdGenerator.java | 5 +- .../document/id/JdkSha256HexIdGenerator.java | 7 +- .../ai/document/id/RandomIdGenerator.java | 5 +- .../ai/embedding/AbstractEmbeddingModel.java | 9 +- .../ai/embedding/BatchingStrategy.java | 5 +- .../ai/embedding/DocumentEmbeddingModel.java | 5 +- .../embedding/DocumentEmbeddingRequest.java | 5 +- .../ai/embedding/Embedding.java | 15 +- .../ai/embedding/EmbeddingModel.java | 11 +- .../ai/embedding/EmbeddingOptions.java | 3 +- .../ai/embedding/EmbeddingOptionsBuilder.java | 13 +- .../ai/embedding/EmbeddingRequest.java | 5 +- .../ai/embedding/EmbeddingResponse.java | 25 +- .../embedding/EmbeddingResponseMetadata.java | 9 +- .../ai/embedding/EmbeddingResultMetadata.java | 17 +- .../embedding/TokenCountBatchingStrategy.java | 11 +- ...ltEmbeddingModelObservationConvention.java | 10 +- ...EmbeddingModelMeterObservationHandler.java | 9 +- .../EmbeddingModelObservationContext.java | 17 +- .../EmbeddingModelObservationConvention.java | 5 +- ...mbeddingModelObservationDocumentation.java | 6 +- .../embedding/observation/package-info.java | 4 +- .../ai/evaluation/EvaluationRequest.java | 40 +- .../ai/evaluation/EvaluationResponse.java | 40 +- .../ai/evaluation/Evaluator.java | 22 +- .../ai/evaluation/FactCheckingEvaluator.java | 24 +- .../ai/evaluation/RelevancyEvaluator.java | 38 +- .../org/springframework/ai/image/Image.java | 21 +- .../ai/image/ImageGeneration.java | 12 +- .../ai/image/ImageGenerationMetadata.java | 5 +- .../ai/image/ImageMessage.java | 21 +- .../springframework/ai/image/ImageModel.java | 5 +- .../ai/image/ImageOptions.java | 5 +- .../ai/image/ImageOptionsBuilder.java | 107 +- .../springframework/ai/image/ImagePrompt.java | 26 +- .../ai/image/ImageResponse.java | 27 +- .../ai/image/ImageResponseMetadata.java | 5 +- ...efaultImageModelObservationConvention.java | 10 +- .../ImageModelObservationContext.java | 19 +- .../ImageModelObservationConvention.java | 5 +- .../ImageModelObservationDocumentation.java | 6 +- ...geModelPromptContentObservationFilter.java | 10 +- .../ai/image/observation/package-info.java | 4 +- .../ai/model/AbstractResponseMetadata.java | 22 +- .../ai/model/ChatModelDescription.java | 2 +- .../org/springframework/ai/model/Content.java | 16 + .../ai/model/EmbeddingModelDescription.java | 2 +- .../ai/model/EmbeddingUtils.java | 35 +- .../org/springframework/ai/model/Media.java | 11 +- .../ai/model/MediaContent.java | 16 + .../org/springframework/ai/model/Model.java | 5 +- .../ai/model/ModelDescription.java | 2 +- .../ai/model/ModelOptions.java | 5 +- .../ai/model/ModelOptionsUtils.java | 29 +- .../ai/model/ModelRequest.java | 7 +- .../ai/model/ModelResponse.java | 5 +- .../springframework/ai/model/ModelResult.java | 5 +- .../ai/model/MutableResponseMetadata.java | 24 +- .../ai/model/ResponseMetadata.java | 13 +- .../ai/model/ResultMetadata.java | 5 +- .../ai/model/StreamingModel.java | 5 +- .../function/AbstractFunctionCallback.java | 23 +- .../ai/model/function/FunctionCallback.java | 15 +- .../function/FunctionCallbackContext.java | 24 +- .../function/FunctionCallbackWrapper.java | 69 +- .../function/FunctionCallingOptions.java | 23 +- .../FunctionCallingOptionsBuilder.java | 23 +- .../ai/model/function/ToolCallHelper.java | 82 +- .../ai/model/function/TypeResolverHelper.java | 5 +- .../ErrorLoggingObservationHandler.java | 33 +- .../observation/ModelObservationContext.java | 6 +- .../ModelUsageMetricsGenerator.java | 15 +- .../ai/model/observation/package-info.java | 6 +- .../ai/model/package-info.java | 8 +- .../ai/moderation/Categories.java | 127 +- .../ai/moderation/CategoryScores.java | 136 +- .../ai/moderation/Generation.java | 10 +- .../ai/moderation/Moderation.java | 48 +- .../ModerationGenerationMetadata.java | 2 +- .../ai/moderation/ModerationMessage.java | 18 +- .../ai/moderation/ModerationModel.java | 2 +- .../ai/moderation/ModerationOptions.java | 2 +- .../moderation/ModerationOptionsBuilder.java | 38 +- .../ai/moderation/ModerationPrompt.java | 24 +- .../ai/moderation/ModerationResponse.java | 28 +- .../ModerationResponseMetadata.java | 2 +- .../ai/moderation/ModerationResult.java | 74 +- .../ai/observation/AiOperationMetadata.java | 9 +- .../conventions/AiObservationAttributes.java | 7 +- .../conventions/AiObservationEventNames.java | 7 +- .../AiObservationMetricAttributes.java | 7 +- .../conventions/AiObservationMetricNames.java | 7 +- .../conventions/AiOperationType.java | 5 +- .../observation/conventions/AiProvider.java | 5 +- .../observation/conventions/AiTokenType.java | 7 +- .../observation/conventions/SpringAiKind.java | 5 +- .../VectorStoreObservationAttributes.java | 7 +- .../VectorStoreObservationEventNames.java | 7 +- .../conventions/VectorStoreProvider.java | 29 +- .../VectorStoreSimilarityMetric.java | 29 +- .../observation/conventions/package-info.java | 6 +- .../ai/observation/package-info.java | 6 +- .../ai/observation/tracing/TracingHelper.java | 25 +- .../ai/reader/EmptyJsonMetadataGenerator.java | 5 +- .../ai/reader/ExtractedTextFormatter.java | 169 +- .../ai/reader/JsonMetadataGenerator.java | 7 +- .../springframework/ai/reader/JsonReader.java | 25 +- .../springframework/ai/reader/TextReader.java | 17 +- .../tokenizer/JTokkitTokenCountEstimator.java | 4 +- .../ai/tokenizer/TokenCountEstimator.java | 4 +- .../transformer/ContentFormatTransformer.java | 13 +- .../transformer/KeywordMetadataEnricher.java | 15 +- .../transformer/SummaryMetadataEnricher.java | 31 +- .../ai/transformer/splitter/TextSplitter.java | 14 +- .../splitter/TokenTextSplitter.java | 23 +- .../springframework/ai/util/JacksonUtils.java | 5 +- .../springframework/ai/util/ParsingUtils.java | 9 +- .../ai/vectorstore/SearchRequest.java | 38 +- .../ai/vectorstore/SimpleVectorStore.java | 40 +- .../ai/vectorstore/VectorStore.java | 5 +- .../ai/vectorstore/filter/Filter.java | 42 +- .../filter/FilterExpressionBuilder.java | 34 +- .../filter/FilterExpressionConverter.java | 7 +- .../filter/FilterExpressionTextParser.java | 9 +- .../ai/vectorstore/filter/FilterHelper.java | 13 +- .../filter/antlr4/FiltersBaseListener.java | 24 +- .../filter/antlr4/FiltersBaseVisitor.java | 24 +- .../filter/antlr4/FiltersLexer.java | 231 ++- .../filter/antlr4/FiltersListener.java | 24 +- .../filter/antlr4/FiltersParser.java | 1648 +++++++++-------- .../filter/antlr4/FiltersVisitor.java | 28 +- .../AbstractFilterExpressionConverter.java | 9 +- .../PineconeFilterExpressionConverter.java | 7 +- .../PrintFilterExpressionConverter.java | 7 +- .../AbstractObservationVectorStore.java | 37 +- ...faultVectorStoreObservationConvention.java | 35 +- ...ectorStoreObservationContentProcessor.java | 12 +- .../VectorStoreObservationContext.java | 117 +- .../VectorStoreObservationConvention.java | 31 +- .../VectorStoreObservationDocumentation.java | 37 +- ...orStoreQueryResponseObservationFilter.java | 9 +- ...rStoreQueryResponseObservationHandler.java | 6 +- .../vectorstore/observation/package-info.java | 6 +- .../ai/writer/FileDocumentWriter.java | 5 +- .../embedding-model-dimensions.properties | 15 + .../springframework/ai/TestConfiguration.java | 5 +- .../ai/aot/AiRuntimeHintsTests.java | 40 +- .../ai/aot/KnuddelsRuntimeHintsTest.java | 8 +- .../ai/aot/SpringAiCoreRuntimeHintsTest.java | 7 +- .../ai/chat/ChatBuilderTests.java | 10 +- .../ai/chat/ChatModelTests.java | 38 +- .../chat/client/ChatClientAdvisorTests.java | 41 +- .../client/ChatClientResponseEntityTests.java | 36 +- .../ai/chat/client/ChatClientTest.java | 150 +- .../ai/chat/client/advisor/AdvisorsTests.java | 171 +- .../advisor/QuestionAnswerAdvisorTests.java | 60 +- .../advisor/SimpleLoggerAdvisorTests.java | 22 +- .../AdvisorObservationContextTests.java | 9 +- ...aultAdvisorObservationConventionTests.java | 15 +- ...entInputContentObservationFilterTests.java | 31 +- .../ChatClientObservationContextTests.java | 13 +- ...tChatClientObservationConventionTests.java | 113 +- .../ai/chat/metadata/DefaultUsageTests.java | 30 +- .../ai/chat/model/GenerationTests.java | 33 +- ...ModelCompletionObservationFilterTests.java | 18 +- ...odelCompletionObservationHandlerTests.java | 10 +- ...ChatModelMeterObservationHandlerTests.java | 27 +- .../ChatModelObservationContextTests.java | 6 +- ...elPromptContentObservationFilterTests.java | 18 +- ...lPromptContentObservationHandlerTests.java | 6 +- ...ltChatModelObservationConventionTests.java | 10 +- .../ai/converter/BeanOutputConverterTest.java | 146 +- .../ai/converter/ListOutputConverterTest.java | 7 +- .../ai/document/ContentFormatterTests.java | 36 +- .../ai/document/DocumentBuilderTests.java | 102 +- .../document/id/IdGeneratorProviderTest.java | 7 +- .../id/JdkSha256HexIdGeneratorTest.java | 19 +- .../AbstractEmbeddingModelTests.java | 18 +- .../TokenCountBatchingStrategyTests.java | 9 +- ...eddingModelObservationConventionTests.java | 12 +- ...dingModelMeterObservationHandlerTests.java | 29 +- ...EmbeddingModelObservationContextTests.java | 10 +- ...tImageModelObservationConventionTests.java | 6 +- .../ImageModelObservationContextTests.java | 6 +- ...elPromptContentObservationFilterTests.java | 18 +- .../ai/metadata/PromptMetadataTests.java | 15 +- .../ai/metadata/UsageTests.java | 12 +- .../ai/model/ModelOptionsUtilsTests.java | 206 ++- .../function/StandaloneWeatherFunction.java | 2 +- .../model/function/TypeResolverHelperIT.java | 17 +- .../function/TypeResolverHelperTests.java | 17 +- .../ModelObservationContextTests.java | 17 + .../ModelUsageMetricsGeneratorTests.java | 14 +- .../observation/AiOperationMetadataTests.java | 5 +- .../tracing/TracingHelperTests.java | 25 +- .../springframework/ai/prompt/ChatTests.java | 5 +- .../ai/prompt/PromptTemplateTest.java | 48 +- .../ai/prompt/PromptTests.java | 16 +- .../ai/reader/JsonReaderTests.java | 26 +- .../ai/reader/TextReaderTests.java | 12 +- .../splitter/TextSplitterTests.java | 5 +- .../splitter/TokenTextSplitterTest.java | 23 +- .../filter/FilterExpressionBuilderTests.java | 30 +- .../FilterExpressionTextParserTests.java | 52 +- .../vectorstore/filter/FilterHelperTests.java | 7 +- .../filter/SearchRequestTests.java | 11 +- ...ineconeFilterExpressionConverterTests.java | 23 +- ...VectorStoreObservationConventionTests.java | 13 +- .../VectorStoreObservationContextTests.java | 9 +- ...reQueryResponseObservationFilterTests.java | 19 +- ...eQueryResponseObservationHandlerTests.java | 10 +- .../application-logging-test.properties | 15 + spring-ai-core/src/test/resources/bikes.json | 528 +++--- spring-ai-core/src/test/resources/logback.xml | 40 +- spring-ai-docs/pom.xml | 16 + spring-ai-docs/src/assembly/javadocs.xml | 16 + .../main/antora/modules/ROOT/images/no.svg | 35 +- .../spring-ai-integration-diagram-3.svg | 48 +- .../ROOT/images/spring_ai_logo_with_text.svg | 38 +- .../main/antora/modules/ROOT/images/yes.svg | 18 +- .../modules/ROOT/pages/api/advisors.adoc | 2 +- .../modules/ROOT/pages/api/aimetadata.adoc | 2 +- .../pages/api/audio/speech/openai-speech.adoc | 20 +- .../azure-openai-transcriptions.adoc | 12 +- .../transcriptions/openai-transcriptions.adoc | 12 +- .../ROOT/pages/api/chat/anthropic-chat.adoc | 22 +- .../pages/api/chat/azure-openai-chat.adoc | 14 +- .../api/chat/bedrock/bedrock-anthropic.adoc | 16 +- .../api/chat/bedrock/bedrock-anthropic3.adoc | 20 +- .../api/chat/bedrock/bedrock-cohere.adoc | 16 +- .../api/chat/bedrock/bedrock-jurassic2.adoc | 8 +- .../pages/api/chat/bedrock/bedrock-llama.adoc | 16 +- .../pages/api/chat/bedrock/bedrock-titan.adoc | 16 +- .../functions/anthropic-chat-functions.adoc | 4 +- .../azure-open-ai-chat-functions.adoc | 4 +- .../functions/minimax-chat-functions.adoc | 4 +- .../functions/mistralai-chat-functions.adoc | 4 +- .../functions/moonshot-chat-functions.adoc | 4 +- .../chat/functions/ollama-chat-functions.adoc | 4 +- .../chat/functions/openai-chat-functions.adoc | 8 +- .../vertexai-gemini-chat-functions.adoc | 4 +- .../functions/zhipuai-chat-functions.adoc | 4 +- .../ROOT/pages/api/chat/groq-chat.adoc | 10 +- .../ROOT/pages/api/chat/huggingface.adoc | 4 +- .../ROOT/pages/api/chat/minimax-chat.adoc | 26 +- .../ROOT/pages/api/chat/mistralai-chat.adoc | 18 +- .../ROOT/pages/api/chat/moonshot-chat.adoc | 18 +- .../ROOT/pages/api/chat/nvidia-chat.adoc | 4 +- .../ROOT/pages/api/chat/ollama-chat.adoc | 18 +- .../ROOT/pages/api/chat/openai-chat.adoc | 38 +- .../ROOT/pages/api/chat/qianfan-chat.adoc | 18 +- .../pages/api/chat/vertexai-gemini-chat.adoc | 12 +- .../pages/api/chat/vertexai-palm2-chat.adoc | 16 +- .../ROOT/pages/api/chat/watsonx-ai-chat.adoc | 4 +- .../ROOT/pages/api/chat/zhipuai-chat.adoc | 18 +- .../modules/ROOT/pages/api/chatclient.adoc | 16 +- .../embeddings/azure-openai-embeddings.adoc | 4 +- .../embeddings/bedrock-cohere-embedding.adoc | 6 +- .../embeddings/bedrock-titan-embedding.adoc | 10 +- .../api/embeddings/minimax-embeddings.adoc | 4 +- .../api/embeddings/mistralai-embeddings.adoc | 4 +- .../api/embeddings/oci-genai-embeddings.adoc | 14 +- .../api/embeddings/ollama-embeddings.adoc | 4 +- .../ROOT/pages/api/embeddings/onnx.adoc | 2 +- .../api/embeddings/openai-embeddings.adoc | 4 +- .../api/embeddings/postgresml-embeddings.adoc | 2 +- .../api/embeddings/qianfan-embeddings.adoc | 2 +- .../vertexai-embeddings-multimodal.adoc | 10 +- .../embeddings/vertexai-embeddings-palm2.adoc | 12 +- .../embeddings/vertexai-embeddings-text.adoc | 4 +- .../api/embeddings/zhipuai-embeddings.adoc | 4 +- .../modules/ROOT/pages/api/etl-pipeline.adoc | 26 +- .../modules/ROOT/pages/api/functions.adoc | 6 +- .../api/moderation/openai-moderation.adoc | 14 +- .../modules/ROOT/pages/api/multimodality.adoc | 4 +- .../antora/modules/ROOT/pages/api/prompt.adoc | 12 +- .../api/structured-output-converter.adoc | 36 +- .../modules/ROOT/pages/api/vectordbs.adoc | 6 +- .../pages/api/vectordbs/apache-cassandra.adoc | 2 +- .../pages/api/vectordbs/azure-cosmos-db.adoc | 18 +- .../ROOT/pages/api/vectordbs/chroma.adoc | 2 +- .../pages/api/vectordbs/elasticsearch.adoc | 2 +- .../ROOT/pages/api/vectordbs/hana.adoc | 12 +- .../ROOT/pages/api/vectordbs/milvus.adoc | 2 +- .../ROOT/pages/api/vectordbs/mongodb.adoc | 2 +- .../ROOT/pages/api/vectordbs/opensearch.adoc | 2 +- .../ROOT/pages/api/vectordbs/oracle.adoc | 2 +- .../ROOT/pages/api/vectordbs/pgvector.adoc | 2 +- .../ROOT/pages/api/vectordbs/pinecone.adoc | 2 +- .../ROOT/pages/api/vectordbs/qdrant.adoc | 2 +- .../ROOT/pages/api/vectordbs/redis.adoc | 2 +- .../ROOT/pages/api/vectordbs/typesense.adoc | 2 +- .../ROOT/pages/api/vectordbs/weaviate.adoc | 2 +- .../modules/ROOT/pages/upgrade-notes.adoc | 8 +- spring-ai-docs/src/main/javadoc/overview.html | 16 + spring-ai-retry/pom.xml | 16 + .../ai/retry/NonTransientAiException.java | 5 +- .../springframework/ai/retry/RetryUtils.java | 36 +- .../ai/retry/TransientAiException.java | 5 +- spring-ai-spring-boot-autoconfigure/pom.xml | 16 + .../anthropic/AnthropicAutoConfiguration.java | 9 +- .../anthropic/AnthropicChatProperties.java | 13 +- .../AnthropicConnectionProperties.java | 5 +- ...ureOpenAiAudioTranscriptionProperties.java | 9 +- .../openai/AzureOpenAiAutoConfiguration.java | 20 +- .../openai/AzureOpenAiChatProperties.java | 5 +- .../AzureOpenAiConnectionProperties.java | 12 +- .../AzureOpenAiEmbeddingProperties.java | 11 +- .../AzureOpenAiImageOptionsProperties.java | 20 +- .../BedrockAwsConnectionConfiguration.java | 5 +- .../BedrockAwsConnectionProperties.java | 17 +- ...BedrockAnthropicChatAutoConfiguration.java | 8 +- .../BedrockAnthropicChatProperties.java | 7 +- ...edrockAnthropic3ChatAutoConfiguration.java | 8 +- .../BedrockAnthropic3ChatProperties.java | 9 +- .../BedrockCohereChatAutoConfiguration.java | 8 +- .../cohere/BedrockCohereChatProperties.java | 5 +- ...drockCohereEmbeddingAutoConfiguration.java | 3 +- .../BedrockCohereEmbeddingProperties.java | 5 +- ...ockAi21Jurassic2ChatAutoConfiguration.java | 9 +- .../BedrockAi21Jurassic2ChatProperties.java | 4 +- .../BedrockLlamaChatAutoConfiguration.java | 7 +- .../llama/BedrockLlamaChatProperties.java | 5 +- .../BedrockTitanChatAutoConfiguration.java | 10 +- .../titan/BedrockTitanChatProperties.java | 11 +- ...edrockTitanEmbeddingAutoConfiguration.java | 5 +- .../BedrockTitanEmbeddingProperties.java | 21 +- .../client/ChatClientAutoConfiguration.java | 6 +- .../client/ChatClientBuilderConfigurer.java | 2 +- .../client/ChatClientBuilderProperties.java | 7 +- .../memory/CommonChatMemoryProperties.java | 7 +- .../CassandraChatMemoryAutoConfiguration.java | 7 +- .../CassandraChatMemoryProperties.java | 12 +- .../ChatObservationAutoConfiguration.java | 30 +- .../ChatObservationProperties.java | 9 +- .../chat/observation/package-info.java | 4 +- ...EmbeddingObservationAutoConfiguration.java | 6 +- .../embedding/observation/package-info.java | 4 +- .../HuggingfaceChatAutoConfiguration.java | 5 +- .../HuggingfaceChatProperties.java | 11 +- .../ImageObservationAutoConfiguration.java | 6 +- .../ImageObservationProperties.java | 7 +- .../image/observation/package-info.java | 4 +- .../minimax/MiniMaxAutoConfiguration.java | 9 +- .../minimax/MiniMaxChatProperties.java | 7 +- .../minimax/MiniMaxConnectionProperties.java | 5 +- .../minimax/MiniMaxEmbeddingProperties.java | 5 +- .../minimax/MiniMaxParentProperties.java | 9 +- .../mistralai/MistralAiAutoConfiguration.java | 9 +- .../mistralai/MistralAiChatProperties.java | 13 +- .../mistralai/MistralAiCommonProperties.java | 5 +- .../MistralAiEmbeddingProperties.java | 9 +- .../mistralai/MistralAiParentProperties.java | 5 +- .../moonshot/MoonshotAutoConfiguration.java | 9 +- .../moonshot/MoonshotChatProperties.java | 5 +- .../moonshot/MoonshotCommonProperties.java | 5 +- .../moonshot/MoonshotParentProperties.java | 5 +- .../oci/genai/OCIConnectionProperties.java | 61 +- .../genai/OCIEmbeddingModelProperties.java | 24 +- .../oci/genai/OCIGenAiAutoConfiguration.java | 40 +- .../autoconfigure/oci/genai/ServingMode.java | 7 +- .../ollama/OllamaAutoConfiguration.java | 25 +- .../ollama/OllamaChatProperties.java | 13 +- .../ollama/OllamaConnectionDetails.java | 5 +- .../ollama/OllamaConnectionProperties.java | 7 +- .../ollama/OllamaEmbeddingProperties.java | 13 +- .../OllamaInitializationProperties.java | 35 +- .../openai/OpenAiAudioSpeechProperties.java | 8 +- .../OpenAiAudioTranscriptionProperties.java | 9 +- .../openai/OpenAiAutoConfiguration.java | 72 +- .../openai/OpenAiChatProperties.java | 13 +- .../openai/OpenAiConnectionProperties.java | 5 +- .../openai/OpenAiEmbeddingProperties.java | 7 +- .../openai/OpenAiImageProperties.java | 7 +- .../openai/OpenAiModerationProperties.java | 4 +- .../openai/OpenAiParentProperties.java | 9 +- .../PostgresMlAutoConfiguration.java | 5 +- .../PostgresMlEmbeddingProperties.java | 5 +- .../qianfan/QianFanAutoConfiguration.java | 9 +- .../qianfan/QianFanChatProperties.java | 7 +- .../qianfan/QianFanConnectionProperties.java | 5 +- .../qianfan/QianFanEmbeddingProperties.java | 5 +- .../qianfan/QianFanImageProperties.java | 7 +- .../qianfan/QianFanParentProperties.java | 11 +- .../retry/SpringAiRetryAutoConfiguration.java | 10 +- .../retry/SpringAiRetryProperties.java | 83 +- .../StabilityAiConnectionProperties.java | 5 +- .../StabilityAiImageAutoConfiguration.java | 5 +- .../StabilityAiImageProperties.java | 12 +- .../StabilityAiParentProperties.java | 9 +- ...ormersEmbeddingModelAutoConfiguration.java | 13 +- .../TransformersEmbeddingModelProperties.java | 89 +- .../CommonVectorStoreProperties.java | 5 +- .../AzureVectorStoreAutoConfiguration.java | 7 +- .../azure/AzureVectorStoreProperties.java | 15 +- ...CassandraVectorStoreAutoConfiguration.java | 5 +- .../CassandraVectorStoreProperties.java | 5 +- .../chroma/ChromaApiProperties.java | 7 +- .../chroma/ChromaConnectionDetails.java | 5 +- .../ChromaVectorStoreAutoConfiguration.java | 9 +- .../chroma/ChromaVectorStoreProperties.java | 5 +- .../CosmosDBVectorStoreAutoConfiguration.java | 8 +- .../CosmosDBVectorStoreProperties.java | 4 +- ...ticsearchVectorStoreAutoConfiguration.java | 5 +- .../ElasticsearchVectorStoreProperties.java | 9 +- .../gemfire/GemFireConnectionDetails.java | 5 +- .../GemFireVectorStoreAutoConfiguration.java | 4 +- .../gemfire/GemFireVectorStoreProperties.java | 18 +- ...HanaCloudVectorStoreAutoConfiguration.java | 11 +- .../HanaCloudVectorStoreProperties.java | 9 +- .../MilvusServiceClientConnectionDetails.java | 5 +- .../milvus/MilvusServiceClientProperties.java | 59 +- .../MilvusVectorStoreAutoConfiguration.java | 8 +- .../milvus/MilvusVectorStoreProperties.java | 89 +- ...goDBAtlasVectorStoreAutoConfiguration.java | 12 +- .../MongoDBAtlasVectorStoreProperties.java | 9 +- .../Neo4jVectorStoreAutoConfiguration.java | 5 +- .../neo4j/Neo4jVectorStoreProperties.java | 5 +- ...ctorStoreObservationAutoConfiguration.java | 16 +- .../VectorStoreObservationProperties.java | 5 +- .../vectorstore/observation/package-info.java | 4 +- .../OpenSearchConnectionDetails.java | 9 +- ...penSearchVectorStoreAutoConfiguration.java | 23 +- .../OpenSearchVectorStoreProperties.java | 17 +- .../OracleVectorStoreAutoConfiguration.java | 6 +- .../oracle/OracleVectorStoreProperties.java | 19 +- .../PgVectorStoreAutoConfiguration.java | 8 +- .../pgvector/PgVectorStoreProperties.java | 11 +- .../PineconeVectorStoreAutoConfiguration.java | 6 +- .../PineconeVectorStoreProperties.java | 5 +- .../qdrant/QdrantConnectionDetails.java | 5 +- .../QdrantVectorStoreAutoConfiguration.java | 2 +- .../qdrant/QdrantVectorStoreProperties.java | 5 +- .../RedisVectorStoreAutoConfiguration.java | 8 +- .../redis/RedisVectorStoreProperties.java | 5 +- .../typesense/TypesenseConnectionDetails.java | 5 +- .../TypesenseServiceClientProperties.java | 10 +- ...TypesenseVectorStoreAutoConfiguration.java | 20 +- .../TypesenseVectorStoreProperties.java | 6 +- .../weaviate/WeaviateConnectionDetails.java | 5 +- .../WeaviateVectorStoreAutoConfiguration.java | 2 +- .../WeaviateVectorStoreProperties.java | 28 +- .../VertexAiEmbeddingAutoConfiguration.java | 12 +- ...VertexAiEmbeddingConnectionProperties.java | 5 +- ...VertexAiMultimodalEmbeddingProperties.java | 5 +- .../VertexAiTextEmbeddingProperties.java | 5 +- .../VertexAiGeminiAutoConfiguration.java | 12 +- .../gemini/VertexAiGeminiChatProperties.java | 5 +- .../VertexAiGeminiConnectionProperties.java | 37 +- .../palm2/VertexAiPalm2AutoConfiguration.java | 5 +- .../VertexAiPalm2ConnectionProperties.java | 5 +- .../VertexAiPalm2EmbeddingProperties.java | 5 +- .../palm2/VertexAiPlam2ChatProperties.java | 5 +- .../watsonxai/WatsonxAiAutoConfiguration.java | 5 +- .../watsonxai/WatsonxAiChatProperties.java | 9 +- .../WatsonxAiConnectionProperties.java | 17 +- .../WatsonxAiEmbeddingProperties.java | 24 +- .../zhipuai/ZhiPuAiAutoConfiguration.java | 9 +- .../zhipuai/ZhiPuAiChatProperties.java | 7 +- .../zhipuai/ZhiPuAiConnectionProperties.java | 5 +- .../zhipuai/ZhiPuAiEmbeddingProperties.java | 5 +- .../zhipuai/ZhiPuAiImageProperties.java | 7 +- .../zhipuai/ZhiPuAiParentProperties.java | 9 +- ...ot.autoconfigure.AutoConfiguration.imports | 16 + .../AnthropicAutoConfigurationIT.java | 17 +- .../anthropic/AnthropicPropertiesTests.java | 5 +- .../tool/FunctionCallWithFunctionBeanIT.java | 22 +- .../FunctionCallWithPromptFunctionIT.java | 16 +- .../anthropic/tool/MockWeatherService.java | 57 +- .../azure/AzureOpenAiAutoConfigurationIT.java | 66 +- ...eOpenAiAutoConfigurationPropertyTests.java | 5 +- ...OpenAiDirectOpenAiAutoConfigurationIT.java | 25 +- .../azure/tool/DeploymentNameUtil.java | 16 + .../tool/FunctionCallWithFunctionBeanIT.java | 26 +- .../FunctionCallWithFunctionWrapperIT.java | 14 +- .../FunctionCallWithPromptFunctionIT.java | 14 +- .../azure/tool/MockWeatherService.java | 57 +- .../BedrockAwsConnectionConfigurationIT.java | 17 +- ...drockAnthropicChatAutoConfigurationIT.java | 22 +- ...rockAnthropic3ChatAutoConfigurationIT.java | 22 +- .../BedrockCohereChatAutoConfigurationIT.java | 24 +- ...ockCohereEmbeddingAutoConfigurationIT.java | 16 +- ...kAi21Jurassic2ChatAutoConfigurationIT.java | 20 +- .../BedrockLlamaChatAutoConfigurationIT.java | 24 +- .../BedrockTitanChatAutoConfigurationIT.java | 21 +- ...rockTitanEmbeddingAutoConfigurationIT.java | 9 +- .../client/ChatClientAutoConfigurationIT.java | 14 +- ...ientObservationAutoConfigurationTests.java | 14 +- ...assandraChatMemoryAutoConfigurationIT.java | 21 +- .../CassandraChatMemoryPropertiesTest.java | 5 +- ...ChatObservationAutoConfigurationTests.java | 32 +- ...dingObservationAutoConfigurationTests.java | 10 +- .../HuggingfaceChatAutoConfigurationIT.java | 17 +- ...mageObservationAutoConfigurationTests.java | 10 +- .../minimax/FunctionCallbackInPromptIT.java | 24 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 34 +- .../minimax/FunctionCallbackWrapperIT.java | 24 +- .../minimax/MiniMaxAutoConfigurationIT.java | 20 +- .../minimax/MiniMaxPropertiesTests.java | 6 +- .../minimax/MockWeatherService.java | 65 +- .../MistralAiAutoConfigurationIT.java | 17 +- .../mistralai/MistralAiPropertiesTests.java | 5 +- .../mistralai/tool/PaymentStatusBeanIT.java | 49 +- .../tool/PaymentStatusBeanOpenAiIT.java | 49 +- .../mistralai/tool/PaymentStatusPromptIT.java | 51 +- .../tool/WeatherServicePromptIT.java | 49 +- .../moonshot/MoonshotAutoConfigurationIT.java | 18 +- .../moonshot/MoonshotPropertiesTests.java | 6 +- .../tool/FunctionCallbackInPromptIT.java | 24 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 34 +- .../tool/FunctionCallbackWrapperIT.java | 26 +- .../moonshot/tool/MockWeatherService.java | 61 +- .../genai/OCIGenAiAutoConfigurationIT.java | 12 +- .../ai/autoconfigure/ollama/BaseOllamaIT.java | 33 +- .../ollama/OllamaChatAutoConfigurationIT.java | 41 +- .../OllamaChatAutoConfigurationTests.java | 5 +- .../OllamaEmbeddingAutoConfigurationIT.java | 32 +- ...OllamaEmbeddingAutoConfigurationTests.java | 5 +- .../ai/autoconfigure/ollama/OllamaImage.java | 5 +- .../tool/FunctionCallbackInPromptIT.java | 29 +- .../tool/FunctionCallbackWrapperIT.java | 32 +- .../ollama/tool/MockWeatherService.java | 61 +- .../openai/OpenAiAutoConfigurationIT.java | 27 +- .../openai/OpenAiPropertiesTests.java | 11 +- .../OpenAiResponseFormatPropertiesTests.java | 10 +- .../tool/FunctionCallbackInPrompt2IT.java | 24 +- .../tool/FunctionCallbackInPromptIT.java | 21 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 8 +- .../tool/FunctionCallbackWrapper2IT.java | 20 +- .../tool/FunctionCallbackWrapperIT.java | 21 +- .../openai/tool/MockWeatherService.java | 61 +- .../PostgresMlAutoConfigurationIT.java | 13 +- .../PostgresMlEmbeddingPropertiesTests.java | 5 +- .../qianfan/QianFanAutoConfigurationIT.java | 26 +- .../qianfan/QianFanPropertiesTests.java | 6 +- .../SpringAiRetryAutoConfigurationIT.java | 5 +- .../retry/SpringAiRetryPropertiesTests.java | 5 +- .../StabilityAiAutoConfigurationIT.java | 10 +- .../StabilityAiImagePropertiesTests.java | 5 +- ...mersEmbeddingModelAutoConfigurationIT.java | 30 +- .../AzureVectorStoreAutoConfigurationIT.java | 35 +- ...ssandraVectorStoreAutoConfigurationIT.java | 55 +- .../CassandraVectorStorePropertiesTests.java | 5 +- .../ChromaVectorStoreAutoConfigurationIT.java | 8 +- ...osmosDBVectorStoreAutoConfigurationIT.java | 41 +- ...csearchVectorStoreAutoConfigurationIT.java | 37 +- ...GemFireVectorStoreAutoConfigurationIT.java | 71 +- .../GemFireVectorStorePropertiesTests.java | 4 +- ...naCloudVectorStoreAutoConfigurationIT.java | 42 +- .../HanaCloudVectorStorePropertiesTest.java | 5 +- .../MilvusVectorStoreAutoConfigurationIT.java | 33 +- ...DBAtlasVectorStoreAutoConfigurationIT.java | 90 +- .../Neo4jVectorStoreAutoConfigurationIT.java | 37 +- .../observation/ObservationTestUtil.java | 35 +- ...toreObservationAutoConfigurationTests.java | 16 +- ...nSearchVectorStoreAutoConfigurationIT.java | 42 +- ...nSearchVectorStoreAutoConfigurationIT.java | 39 +- .../OracleVectorStoreAutoConfigurationIT.java | 59 +- .../OracleVectorStorePropertiesTests.java | 10 +- .../PgVectorStoreAutoConfigurationIT.java | 65 +- .../PgVectorStorePropertiesTests.java | 9 +- ...ineconeVectorStoreAutoConfigurationIT.java | 43 +- .../PineconeVectorStorePropertiesTests.java | 10 +- .../QdrantVectorStoreAutoConfigurationIT.java | 55 +- ...ntVectorStoreCloudAutoConfigurationIT.java | 47 +- .../QdrantVectorStorePropertiesTests.java | 6 +- .../RedisVectorStoreAutoConfigurationIT.java | 36 +- .../RedisVectorStorePropertiesTests.java | 9 +- ...pesenseVectorStoreAutoConfigurationIT.java | 33 +- ...eaviateVectorStoreAutoConfigurationIT.java | 10 +- ...TextEmbeddingModelAutoConfigurationIT.java | 32 +- .../VertexAiGeminiAutoConfigurationIT.java | 13 +- .../tool/FunctionCallWithFunctionBeanIT.java | 26 +- .../FunctionCallWithFunctionWrapperIT.java | 16 +- .../FunctionCallWithPromptFunctionIT.java | 18 +- .../gemini/tool/MockWeatherService.java | 57 +- .../VertexAiPaLm2AutoConfigurationIT.java | 26 +- .../WatsonxAiAutoConfigurationTests.java | 8 +- .../zhipuai/ZhiPuAiAutoConfigurationIT.java | 22 +- .../zhipuai/ZhiPuAiPropertiesTests.java | 6 +- .../tool/FunctionCallbackInPromptIT.java | 24 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 34 +- .../tool/FunctionCallbackWrapperIT.java | 24 +- .../zhipuai/tool/MockWeatherService.java | 65 +- .../src/test/resources/oracle/initialize.sql | 16 + spring-ai-spring-boot-docker-compose/pom.xml | 16 + ...DockerComposeConnectionDetailsFactory.java | 5 +- .../connection/chroma/ChromaEnvironment.java | 5 +- ...DockerComposeConnectionDetailsFactory.java | 6 +- ...DockerComposeConnectionDetailsFactory.java | 5 +- ...DockerComposeConnectionDetailsFactory.java | 9 +- .../opensearch/OpenSearchEnvironment.java | 5 +- ...DockerComposeConnectionDetailsFactory.java | 5 +- .../connection/qdrant/QdrantEnvironment.java | 16 + ...DockerComposeConnectionDetailsFactory.java | 5 +- .../typesense/TypesenseEnvironment.java | 16 + ...DockerComposeConnectionDetailsFactory.java | 5 +- .../main/resources/META-INF/spring.factories | 16 + ...rComposeConnectionDetailsFactoryTests.java | 8 +- .../chroma/ChromaEnvironmentTests.java | 9 +- ...rComposeConnectionDetailsFactoryTests.java | 8 +- ...rComposeConnectionDetailsFactoryTests.java | 19 +- ...rComposeConnectionDetailsFactoryTests.java | 8 +- ...rComposeConnectionDetailsFactoryTests.java | 8 +- .../OpenSearchEnvironmentTests.java | 9 +- ...rComposeConnectionDetailsFactoryTests.java | 8 +- ...rComposeConnectionDetailsFactoryTests.java | 8 +- .../typesense/TypesenseEnvironmentTests.java | 9 +- ...rComposeConnectionDetailsFactoryTests.java | 8 +- ...AbstractDockerComposeIntegrationTests.java | 18 +- .../DisabledIfProcessUnavailable.java | 6 +- ...DisabledIfProcessUnavailableCondition.java | 15 +- .../DisabledIfProcessUnavailables.java | 6 +- .../spring-ai-starter-anthropic/pom.xml | 16 + .../pom.xml | 16 + .../pom.xml | 16 + .../spring-ai-starter-azure-openai/pom.xml | 16 + .../spring-ai-starter-azure-store/pom.xml | 16 + .../spring-ai-starter-bedrock-ai/pom.xml | 16 + .../spring-ai-starter-cassandra-store/pom.xml | 16 + .../spring-ai-starter-chroma-store/pom.xml | 16 + .../pom.xml | 16 + .../spring-ai-starter-gemfire-store/pom.xml | 16 + .../spring-ai-starter-hanadb-store/pom.xml | 16 + .../spring-ai-starter-huggingface/pom.xml | 16 + .../spring-ai-starter-milvus-store/pom.xml | 16 + .../spring-ai-starter-minimax/pom.xml | 16 + .../spring-ai-starter-mistral-ai/pom.xml | 16 + .../pom.xml | 16 + .../spring-ai-starter-moonshot/pom.xml | 16 + .../spring-ai-starter-neo4j-store/pom.xml | 16 + .../spring-ai-starter-oci-genai/pom.xml | 16 + .../spring-ai-starter-ollama/pom.xml | 16 + .../spring-ai-starter-openai/pom.xml | 16 + .../pom.xml | 16 + .../spring-ai-starter-oracle-store/pom.xml | 16 + .../spring-ai-starter-pgvector-store/pom.xml | 16 + .../spring-ai-starter-pinecone-store/pom.xml | 16 + .../pom.xml | 16 + .../spring-ai-starter-qdrant-store/pom.xml | 16 + .../spring-ai-starter-qianfan/pom.xml | 16 + .../spring-ai-starter-redis-store/pom.xml | 16 + .../spring-ai-starter-stability-ai/pom.xml | 16 + .../spring-ai-starter-transformers/pom.xml | 16 + .../spring-ai-starter-typesense-store/pom.xml | 16 + .../pom.xml | 16 + .../pom.xml | 16 + .../spring-ai-starter-vertex-ai-palm2/pom.xml | 16 + .../spring-ai-starter-watsonx-ai/pom.xml | 16 + .../spring-ai-starter-weaviate-store/pom.xml | 16 + .../spring-ai-starter-zhipuai/pom.xml | 16 + spring-ai-spring-boot-testcontainers/pom.xml | 16 + ...romaContainerConnectionDetailsFactory.java | 12 +- ...lvusContainerConnectionDetailsFactory.java | 8 +- ...ocalContainerConnectionDetailsFactory.java | 8 +- ...lamaContainerConnectionDetailsFactory.java | 8 +- ...archContainerConnectionDetailsFactory.java | 10 +- ...rantContainerConnectionDetailsFactory.java | 8 +- ...enseContainerConnectionDetailsFactory.java | 8 +- ...iateContainerConnectionDetailsFactory.java | 8 +- .../main/resources/META-INF/spring.factories | 16 + ...ContainerConnectionDetailsFactoryTest.java | 13 +- .../connection/chroma/ChromaImage.java | 5 +- ...ContainerConnectionDetailsFactoryTest.java | 13 +- ...ContainerConnectionDetailsFactoryTest.java | 13 +- ...ContainerConnectionDetailsFactoryTest.java | 30 +- .../connection/milvus/MilvusImage.java | 5 +- ...alContainerConnectionDetailsFactoryIt.java | 30 +- .../connection/mongo/MongoDbImage.java | 5 +- ...ContainerConnectionDetailsFactoryTest.java | 24 +- .../connection/ollama/OllamaImage.java | 5 +- ...ContainerConnectionDetailsFactoryTest.java | 36 +- .../opensearch/OpenSearchImage.java | 5 +- ...ContainerConnectionDetailsFactoryTest.java | 52 +- ...ithApiKeyConnectionDetailsFactoryTest.java | 52 +- .../connection/qdrant/QdrantImage.java | 5 +- ...ContainerConnectionDetailsFactoryTest.java | 39 +- .../connection/typesense/TypesenseImage.java | 5 +- ...ContainerConnectionDetailsFactoryTest.java | 48 +- .../connection/weaviate/WeaviateImage.java | 5 +- spring-ai-spring-cloud-bindings/pom.xml | 16 + .../ai/bindings/BindingsValidator.java | 4 +- .../ChromaBindingsPropertiesProcessor.java | 10 +- .../MistralAiBindingsPropertiesProcessor.java | 8 +- .../OllamaBindingsPropertiesProcessor.java | 8 +- .../OpenAiBindingsPropertiesProcessor.java | 8 +- .../TanzuBindingsPropertiesProcessor.java | 10 +- .../WeaviateBindingsPropertiesProcessor.java | 10 +- .../main/resources/META-INF/spring.factories | 16 + ...hromaBindingsPropertiesProcessorTests.java | 29 +- ...ralAiBindingsPropertiesProcessorTests.java | 25 +- ...llamaBindingsPropertiesProcessorTests.java | 23 +- ...penAiBindingsPropertiesProcessorTests.java | 25 +- ...TanzuBindingsPropertiesProcessorTests.java | 39 +- ...viateBindingsPropertiesProcessorTests.java | 27 +- spring-ai-test/pom.xml | 16 + .../ai/evaluation/BasicEvaluationTest.java | 26 +- src/checkstyle/checkstyle-header.txt | 17 + src/checkstyle/checkstyle-suppressions.xml | 32 + src/checkstyle/checkstyle.xml | 185 ++ .../spring-ai-azure-cosmos-db-store/pom.xml | 16 + .../CosmosDBFilterExpressionConverter.java | 12 +- .../ai/vectorstore/CosmosDBVectorStore.java | 49 +- .../CosmosDBVectorStoreConfig.java | 12 +- .../ai/vectorstore/CosmosDBVectorStoreIT.java | 41 +- .../src/test/resources/application.properties | 16 + vector-stores/spring-ai-azure-store/pom.xml | 16 + ...zureAiSearchFilterExpressionConverter.java | 5 +- .../vectorstore/azure/AzureVectorStore.java | 130 +- ...iSearchFilterExpressionConverterTests.java | 7 +- .../vectorstore/azure/AzureVectorStoreIT.java | 51 +- .../azure/AzureVectorStoreObservationIT.java | 32 +- .../spring-ai-cassandra-store/pom.xml | 16 + .../ai/cassandra/SchemaUtil.java | 13 +- .../ai/chat/memory/CassandraChatMemory.java | 37 +- .../memory/CassandraChatMemoryConfig.java | 250 +-- .../CassandraFilterExpressionConverter.java | 53 +- .../ai/vectorstore/CassandraVectorStore.java | 80 +- .../CassandraVectorStoreConfig.java | 480 ++--- .../springframework/ai/CassandraImage.java | 5 +- .../ai/chat/memory/CassandraChatMemoryIT.java | 7 +- ...ssandraFilterExpressionConverterTests.java | 27 +- .../CassandraRichSchemaVectorStoreIT.java | 182 +- .../vectorstore/CassandraVectorStoreIT.java | 61 +- .../CassandraVectorStoreObservationIT.java | 44 +- .../vectorstore/WikiVectorStoreExample.java | 5 +- .../test/resources/test_wiki_full_schema.cql | 16 + .../resources/test_wiki_partial_0_schema.cql | 16 + .../resources/test_wiki_partial_1_schema.cql | 16 + .../resources/test_wiki_partial_2_schema.cql | 16 + .../resources/test_wiki_partial_3_schema.cql | 16 + .../resources/test_wiki_partial_4_schema.cql | 16 + vector-stores/spring-ai-chroma-store/pom.xml | 16 + .../springframework/ai/chroma/ChromaApi.java | 370 ++-- .../ChromaFilterExpressionConverter.java | 5 +- .../ai/vectorstore/ChromaVectorStore.java | 12 +- .../org/springframework/ai/ChromaImage.java | 5 +- .../ai/chroma/ChromaApiIT.java | 75 +- .../vectorstore/BasicAuthChromaWhereIT.java | 20 +- .../ai/vectorstore/ChromaVectorStoreIT.java | 44 +- .../ChromaVectorStoreObservationIT.java | 25 +- .../TokenSecuredChromaWhereIT.java | 20 +- .../spring-ai-elasticsearch-store/pom.xml | 16 + ...archAiSearchFilterExpressionConverter.java | 15 +- .../vectorstore/ElasticsearchVectorStore.java | 55 +- .../ElasticsearchVectorStoreOptions.java | 11 +- .../ai/vectorstore/SimilarityFunction.java | 16 + ...AiSearchFilterExpressionConverterTest.java | 34 +- .../ai/vectorstore/ElasticsearchImage.java | 5 +- .../ElasticsearchVectorStoreIT.java | 21 +- ...ElasticsearchVectorStoreObservationIT.java | 46 +- vector-stores/spring-ai-gemfire-store/pom.xml | 16 + .../ai/vectorstore/GemFireVectorStore.java | 21 +- .../ai/vectorstore/GemFireImage.java | 5 +- .../ai/vectorstore/GemFireVectorStoreIT.java | 52 +- .../GemFireVectorStoreObservationIT.java | 49 +- vector-stores/spring-ai-hanadb-store/pom.xml | 16 + .../ai/vectorstore/HanaCloudVectorStore.java | 13 +- .../HanaCloudVectorStoreConfig.java | 13 +- .../ai/vectorstore/HanaVectorEntity.java | 7 +- .../ai/vectorstore/HanaVectorRepository.java | 5 +- .../ai/vectorstore/CricketWorldCup.java | 7 +- .../CricketWorldCupHanaController.java | 26 +- .../CricketWorldCupRepository.java | 18 +- .../vectorstore/HanaCloudVectorStoreIT.java | 9 +- .../HanaVectorStoreObservationIT.java | 25 +- .../src/test/resources/application.properties | 16 + vector-stores/spring-ai-milvus-store/pom.xml | 16 + .../MilvusFilterExpressionConverter.java | 5 +- .../ai/vectorstore/MilvusVectorStore.java | 320 ++-- .../MilvusEmbeddingDimensionsTests.java | 25 +- .../MilvusFilterExpressionConverterTests.java | 23 +- .../ai/vectorstore/MilvusImage.java | 5 +- .../ai/vectorstore/MilvusVectorStoreIT.java | 217 +-- .../MilvusVectorStoreObservationIT.java | 33 +- .../spring-ai-mongodb-atlas-store/pom.xml | 16 + ...MongoDBAtlasFilterExpressionConverter.java | 5 +- .../vectorstore/MongoDBAtlasVectorStore.java | 37 +- .../vectorstore/VectorSearchAggregation.java | 21 +- .../MongoDBAtlasFilterConverterTest.java | 23 +- .../MongoDBAtlasVectorStoreIT.java | 36 +- .../ai/vectorstore/MongoDbImage.java | 5 +- .../MongoDbVectorStoreObservationIT.java | 33 +- .../VectorSearchAggregationTest.java | 12 +- vector-stores/spring-ai-neo4j-store/pom.xml | 16 + .../ai/vectorstore/Neo4jVectorStore.java | 418 ++--- .../Neo4jVectorFilterExpressionConverter.java | 5 +- .../ai/vectorstore/Neo4jImage.java | 5 +- .../ai/vectorstore/Neo4jVectorStoreIT.java | 15 +- .../Neo4jVectorStoreObservationIT.java | 25 +- ...jVectorFilterExpressionConverterTests.java | 25 +- .../spring-ai-opensearch-store/pom.xml | 16 + ...archAiSearchFilterExpressionConverter.java | 15 +- .../ai/vectorstore/OpenSearchVectorStore.java | 32 +- ...AiSearchFilterExpressionConverterTest.java | 42 +- .../ai/vectorstore/OpenSearchImage.java | 5 +- .../vectorstore/OpenSearchVectorStoreIT.java | 43 +- .../OpenSearchVectorStoreObservationIT.java | 32 +- vector-stores/spring-ai-oracle-store/pom.xml | 16 + .../ai/vectorstore/OracleVectorStore.java | 394 ++-- .../SqlJsonPathFilterExpressionConverter.java | 16 + .../ai/vectorstore/OracleImage.java | 5 +- .../ai/vectorstore/OracleVectorStoreIT.java | 196 +- .../OracleVectorStoreObservationIT.java | 29 +- ...sonPathFilterExpressionConverterTests.java | 21 +- .../src/test/resources/initialize.sql | 16 + .../spring-ai-pgvector-store/pom.xml | 16 + .../PgVectorFilterExpressionConverter.java | 8 +- .../vectorstore/PgVectorSchemaValidator.java | 10 +- .../ai/vectorstore/PgVectorStore.java | 86 +- .../PgVectorEmbeddingDimensionsTests.java | 21 +- ...gVectorFilterExpressionConverterTests.java | 33 +- .../ai/vectorstore/PgVectorImage.java | 5 +- .../PgVectorStoreCustomNamesIT.java | 48 +- .../ai/vectorstore/PgVectorStoreIT.java | 115 +- .../PgVectorStoreObservationIT.java | 48 +- .../ai/vectorstore/PgVectorStoreTests.java | 19 +- .../PgVectorStoreWithChatMemoryAdvisorIT.java | 105 +- .../spring-ai-pinecone-store/pom.xml | 16 + .../ai/vectorstore/PineconeVectorStore.java | 382 ++-- .../vectorstore/PineconeVectorStoreHints.java | 20 +- .../ai/vectorstore/PineconeVectorStoreIT.java | 33 +- .../PineconeVectorStoreObservationIT.java | 29 +- vector-stores/spring-ai-qdrant-store/pom.xml | 16 + .../QdrantFilterExpressionConverter.java | 29 +- .../qdrant/QdrantObjectFactory.java | 5 +- .../qdrant/QdrantValueFactory.java | 5 +- .../vectorstore/qdrant/QdrantVectorStore.java | 168 +- .../ai/vectorstore/qdrant/QdrantImage.java | 5 +- .../qdrant/QdrantVectorStoreIT.java | 37 +- .../QdrantVectorStoreObservationIT.java | 39 +- vector-stores/spring-ai-redis-store/pom.xml | 16 + .../RedisFilterExpressionConverter.java | 10 +- .../ai/vectorstore/RedisVectorStore.java | 364 ++-- .../RedisFilterExpressionConverterTests.java | 28 +- .../ai/vectorstore/RedisVectorStoreIT.java | 35 +- .../RedisVectorStoreObservationIT.java | 38 +- .../spring-ai-typesense-store/pom.xml | 16 + .../TypesenseFilterExpressionConverter.java | 18 +- .../ai/vectorstore/TypesenseVectorStore.java | 174 +- .../ai/vectorstore/TypesenseImage.java | 5 +- .../vectorstore/TypesenseVectorStoreIT.java | 47 +- .../TypesenseVectorStoreObservationIT.java | 31 +- .../spring-ai-weaviate-store/pom.xml | 16 + .../WeaviateFilterExpressionConverter.java | 11 +- .../ai/vectorstore/WeaviateVectorStore.java | 357 ++-- ...eaviateFilterExpressionConverterTests.java | 5 +- .../ai/vectorstore/WeaviateImage.java | 5 +- .../ai/vectorstore/WeaviateVectorStoreIT.java | 41 +- .../WeaviateVectorStoreObservationIT.java | 35 +- 1412 files changed, 25622 insertions(+), 20588 deletions(-) create mode 100644 src/checkstyle/checkstyle-header.txt create mode 100644 src/checkstyle/checkstyle-suppressions.xml create mode 100644 src/checkstyle/checkstyle.xml diff --git a/.devcontainer/scripts/onCreateCommand.sh b/.devcontainer/scripts/onCreateCommand.sh index e4aaf8ace11..eba12a3577a 100755 --- a/.devcontainer/scripts/onCreateCommand.sh +++ b/.devcontainer/scripts/onCreateCommand.sh @@ -1,5 +1,21 @@ #!/bin/bash +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + set -x az extension add --name spring diff --git a/.editorconfig b/.editorconfig index 3e1127a6d67..8cbc8ccb43e 100644 --- a/.editorconfig +++ b/.editorconfig @@ -8,3 +8,5 @@ indent_style = tab indent_size = 4 continuation_indent_size = 8 end_of_line = lf + +insert_final_newline = true diff --git a/.mvn/extensions.xml b/.mvn/extensions.xml index c7e6507acce..31675c58918 100644 --- a/.mvn/extensions.xml +++ b/.mvn/extensions.xml @@ -1,4 +1,20 @@ + + fr.jcgay.maven diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties index dc3affce3dd..da1385eeb9b 100644 --- a/.mvn/wrapper/maven-wrapper.properties +++ b/.mvn/wrapper/maven-wrapper.properties @@ -1,18 +1,17 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.8.6/apache-maven-3.8.6-bin.zip wrapperUrl=https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.1.1/maven-wrapper-3.1.1.jar diff --git a/document-readers/markdown-reader/pom.xml b/document-readers/markdown-reader/pom.xml index 5922ea2b4ed..9ad6aa6a152 100644 --- a/document-readers/markdown-reader/pom.xml +++ b/document-readers/markdown-reader/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java index 7ed8aa6b548..19ebed9cad6 100644 --- a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java +++ b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java @@ -1,18 +1,45 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.reader.markdown; -import org.commonmark.node.*; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.List; + +import org.commonmark.node.AbstractVisitor; +import org.commonmark.node.BlockQuote; +import org.commonmark.node.Code; +import org.commonmark.node.FencedCodeBlock; +import org.commonmark.node.HardLineBreak; +import org.commonmark.node.Heading; +import org.commonmark.node.ListItem; +import org.commonmark.node.Node; +import org.commonmark.node.SoftLineBreak; +import org.commonmark.node.Text; +import org.commonmark.node.ThematicBreak; import org.commonmark.parser.Parser; + import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; -import java.io.IOException; -import java.io.InputStreamReader; -import java.util.ArrayList; -import java.util.List; - /** * Reads the given Markdown resource and groups headers, paragraphs, or text divided by * horizontal lines (depending on the @@ -58,10 +85,10 @@ public MarkdownDocumentReader(Resource markdownResource, MarkdownDocumentReaderC */ @Override public List get() { - try (var input = markdownResource.getInputStream()) { - Node node = parser.parseReader(new InputStreamReader(input)); + try (var input = this.markdownResource.getInputStream()) { + Node node = this.parser.parseReader(new InputStreamReader(input)); - DocumentVisitor documentVisitor = new DocumentVisitor(config); + DocumentVisitor documentVisitor = new DocumentVisitor(this.config); node.accept(documentVisitor); return documentVisitor.getDocuments(); @@ -90,7 +117,7 @@ public DocumentVisitor(MarkdownDocumentReaderConfig config) { @Override public void visit(org.commonmark.node.Document document) { - currentDocumentBuilder = Document.builder(); + this.currentDocumentBuilder = Document.builder(); super.visit(document); } @@ -102,7 +129,7 @@ public void visit(Heading heading) { @Override public void visit(ThematicBreak thematicBreak) { - if (config.horizontalRuleCreateDocument) { + if (this.config.horizontalRuleCreateDocument) { buildAndFlush(); } super.visit(thematicBreak); @@ -128,32 +155,32 @@ public void visit(ListItem listItem) { @Override public void visit(BlockQuote blockQuote) { - if (!config.includeBlockquote) { + if (!this.config.includeBlockquote) { buildAndFlush(); } translateLineBreakToSpace(); - currentDocumentBuilder.withMetadata("category", "blockquote"); + this.currentDocumentBuilder.withMetadata("category", "blockquote"); super.visit(blockQuote); } @Override public void visit(Code code) { - currentParagraphs.add(code.getLiteral()); - currentDocumentBuilder.withMetadata("category", "code_inline"); + this.currentParagraphs.add(code.getLiteral()); + this.currentDocumentBuilder.withMetadata("category", "code_inline"); super.visit(code); } @Override public void visit(FencedCodeBlock fencedCodeBlock) { - if (!config.includeCodeBlock) { + if (!this.config.includeCodeBlock) { buildAndFlush(); } translateLineBreakToSpace(); - currentParagraphs.add(fencedCodeBlock.getLiteral()); - currentDocumentBuilder.withMetadata("category", "code_block"); - currentDocumentBuilder.withMetadata("lang", fencedCodeBlock.getInfo()); + this.currentParagraphs.add(fencedCodeBlock.getLiteral()); + this.currentDocumentBuilder.withMetadata("category", "code_block"); + this.currentDocumentBuilder.withMetadata("lang", fencedCodeBlock.getInfo()); buildAndFlush(); @@ -163,11 +190,11 @@ public void visit(FencedCodeBlock fencedCodeBlock) { @Override public void visit(Text text) { if (text.getParent() instanceof Heading heading) { - currentDocumentBuilder.withMetadata("category", "header_%d".formatted(heading.getLevel())) + this.currentDocumentBuilder.withMetadata("category", "header_%d".formatted(heading.getLevel())) .withMetadata("title", text.getLiteral()); } else { - currentParagraphs.add(text.getLiteral()); + this.currentParagraphs.add(text.getLiteral()); } super.visit(text); @@ -176,29 +203,29 @@ public void visit(Text text) { public List getDocuments() { buildAndFlush(); - return documents; + return this.documents; } private void buildAndFlush() { - if (!currentParagraphs.isEmpty()) { - String content = String.join("", currentParagraphs); + if (!this.currentParagraphs.isEmpty()) { + String content = String.join("", this.currentParagraphs); - Document.Builder builder = currentDocumentBuilder.withContent(content); + Document.Builder builder = this.currentDocumentBuilder.withContent(content); - config.additionalMetadata.forEach(builder::withMetadata); + this.config.additionalMetadata.forEach(builder::withMetadata); Document document = builder.build(); - documents.add(document); + this.documents.add(document); - currentParagraphs.clear(); + this.currentParagraphs.clear(); } - currentDocumentBuilder = Document.builder(); + this.currentDocumentBuilder = Document.builder(); } private void translateLineBreakToSpace() { - if (!currentParagraphs.isEmpty()) { - currentParagraphs.add(" "); + if (!this.currentParagraphs.isEmpty()) { + this.currentParagraphs.add(" "); } } diff --git a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java index d5ad3ec58ce..c22c573f0e8 100644 --- a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java +++ b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java @@ -1,12 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.reader.markdown.config; +import java.util.HashMap; +import java.util.Map; + import org.springframework.ai.document.Document; import org.springframework.ai.reader.markdown.MarkdownDocumentReader; import org.springframework.util.Assert; -import java.util.HashMap; -import java.util.Map; - /** * Common configuration for the {@link MarkdownDocumentReader}. * @@ -23,10 +39,10 @@ public class MarkdownDocumentReaderConfig { public final Map additionalMetadata; public MarkdownDocumentReaderConfig(Builder builder) { - horizontalRuleCreateDocument = builder.horizontalRuleCreateDocument; - includeCodeBlock = builder.includeCodeBlock; - includeBlockquote = builder.includeBlockquote; - additionalMetadata = builder.additionalMetadata; + this.horizontalRuleCreateDocument = builder.horizontalRuleCreateDocument; + this.includeCodeBlock = builder.includeCodeBlock; + this.includeBlockquote = builder.includeBlockquote; + this.additionalMetadata = builder.additionalMetadata; } /** diff --git a/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java b/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java index 739dbbd709b..69d3babe54f 100644 --- a/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java +++ b/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java @@ -1,12 +1,29 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.reader.markdown; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.groups.Tuple.tuple; diff --git a/document-readers/pdf-reader/pom.xml b/document-readers/pdf-reader/pom.xml index c870c9176c2..eace8bd6d2b 100644 --- a/document-readers/pdf-reader/pom.xml +++ b/document-readers/pdf-reader/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java index d1e95cd5057..11fb9933030 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; -import java.awt.Rectangle; +import java.awt.*; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -24,9 +25,9 @@ import org.apache.pdfbox.pdfparser.PDFParser; import org.apache.pdfbox.pdmodel.PDDocument; import org.apache.pdfbox.pdmodel.PDPage; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig; @@ -46,22 +47,22 @@ */ public class PagePdfDocumentReader implements DocumentReader { - private final Logger logger = LoggerFactory.getLogger(getClass()); - - private static final String PDF_PAGE_REGION = "pdfPageRegion"; - public static final String METADATA_START_PAGE_NUMBER = "page_number"; public static final String METADATA_END_PAGE_NUMBER = "end_page_number"; public static final String METADATA_FILE_NAME = "file_name"; + private static final String PDF_PAGE_REGION = "pdfPageRegion"; + protected final PDDocument document; - private PdfDocumentReaderConfig config; + private final Logger logger = LoggerFactory.getLogger(getClass()); protected String resourceFileName; + private PdfDocumentReaderConfig config; + public PagePdfDocumentReader(String resourceUrl) { this(new DefaultResourceLoader().getResource(resourceUrl)); } @@ -103,15 +104,15 @@ public List get() { int totalPages = this.document.getDocumentCatalog().getPages().getCount(); int logFrequency = totalPages > 10 ? totalPages / 10 : 1; // if less than 10 - // pages, print - // each iteration + // pages, print + // each iteration int counter = 0; PDPage lastPage = this.document.getDocumentCatalog().getPages().iterator().next(); for (PDPage page : this.document.getDocumentCatalog().getPages()) { lastPage = page; if (counter % logFrequency == 0 && counter / logFrequency < 10) { - logger.info("Processing PDF page: {}", (counter + 1)); + this.logger.info("Processing PDF page: {}", (counter + 1)); } counter++; @@ -153,7 +154,7 @@ public List get() { readDocuments.add(toDocument(lastPage, pageTextGroupList.stream().collect(Collectors.joining()), startPageNumber, pageNumber)); } - logger.info("Processing {} pages", totalPages); + this.logger.info("Processing {} pages", totalPages); return readDocuments; } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java index 9f5d055305e..a5943d45d36 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; -import java.awt.Rectangle; +import java.awt.*; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import org.apache.pdfbox.pdfparser.PDFParser; import org.apache.pdfbox.pdmodel.PDDocument; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; import org.springframework.ai.reader.pdf.config.ParagraphManager; @@ -48,8 +49,6 @@ */ public class ParagraphPdfDocumentReader implements DocumentReader { - private final Logger logger = LoggerFactory.getLogger(getClass()); - // Constants for metadata keys private static final String METADATA_START_PAGE = "page_number"; @@ -61,14 +60,16 @@ public class ParagraphPdfDocumentReader implements DocumentReader { private static final String METADATA_FILE_NAME = "file_name"; - private final ParagraphManager paragraphTextExtractor; - protected final PDDocument document; - private PdfDocumentReaderConfig config; + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ParagraphManager paragraphTextExtractor; protected String resourceFileName; + private PdfDocumentReaderConfig config; + /** * Constructs a ParagraphPdfDocumentReader using a resource URL. * @param resourceUrl The URL of the PDF resource. @@ -132,7 +133,7 @@ public List get() { List documents = new ArrayList<>(paragraphs.size()); if (!CollectionUtils.isEmpty(paragraphs)) { - logger.info("Start processing paragraphs from PDF"); + this.logger.info("Start processing paragraphs from PDF"); Iterator itr = paragraphs.iterator(); var current = itr.next(); @@ -151,7 +152,7 @@ public List get() { } } } - logger.info("End processing paragraphs from PDF"); + this.logger.info("End processing paragraphs from PDF"); return documents; } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java index 0e2c7fbe975..ae5b8588fed 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.aot; +import java.io.IOException; +import java.util.Set; + import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; -import java.io.IOException; -import java.util.Set; - /** * The PdfReaderRuntimeHints class is responsible for registering runtime hints for PDFBox * resources. diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java index 01188074347..555b23fa034 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.config; import java.io.IOException; @@ -39,34 +40,6 @@ */ public class ParagraphManager { - /** - * Represents a document paragraph metadata and hierarchy. - * - * @param parent Parent paragraph that will contain a children paragraphs. - * @param title Paragraph title as it appears in the PDF document. - * @param level The TOC deepness level for this paragraph. The root is at level 0. - * @param startPageNumber The page number in the PDF where this paragraph begins. - * @param endPageNumber The page number in the PDF where this paragraph ends. - * @param children Sub-paragraphs for this paragraph. - */ - public record Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, - int position, List children) { - - public Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, - int position) { - this(parent, title, level, startPageNumber, endPageNumber, position, new ArrayList<>()); - } - - @Override - public String toString() { - String indent = (level < 0) ? "" : new String(new char[level * 2]).replace('\0', ' '); - - return indent + " " + level + ") " + title + " [" + startPageNumber + "," + endPageNumber + "], children = " - + children.size() + ", pos = " + position; - } - - } - /** * Root of the paragraphs tree. */ @@ -90,7 +63,7 @@ public ParagraphManager(PDDocument document) { new Paragraph(null, "root", -1, 1, this.document.getNumberOfPages(), 0), this.document.getDocumentCatalog().getDocumentOutline(), 0); - printParagraph(rootParagraph, System.out); + printParagraph(this.rootParagraph, System.out); } catch (Exception e) { throw new RuntimeException(e); @@ -203,4 +176,32 @@ else if (paragraph.level() == level) { return resultList; } + /** + * Represents a document paragraph metadata and hierarchy. + * + * @param parent Parent paragraph that will contain a children paragraphs. + * @param title Paragraph title as it appears in the PDF document. + * @param level The TOC deepness level for this paragraph. The root is at level 0. + * @param startPageNumber The page number in the PDF where this paragraph begins. + * @param endPageNumber The page number in the PDF where this paragraph ends. + * @param children Sub-paragraphs for this paragraph. + */ + public record Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, + int position, List children) { + + public Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, + int position) { + this(parent, title, level, startPageNumber, endPageNumber, position, new ArrayList<>()); + } + + @Override + public String toString() { + String indent = (this.level < 0) ? "" : new String(new char[this.level * 2]).replace('\0', ' '); + + return indent + " " + this.level + ") " + this.title + " [" + this.startPageNumber + "," + + this.endPageNumber + "], children = " + this.children.size() + ", pos = " + this.position; + } + + } + } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java index 5a375b3d42e..b80ff8e9bb3 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.config; import org.springframework.ai.reader.ExtractedTextFormatter; @@ -40,6 +41,14 @@ public class PdfDocumentReaderConfig { public final ExtractedTextFormatter pageExtractedTextFormatter; + private PdfDocumentReaderConfig(PdfDocumentReaderConfig.Builder builder) { + this.pagesPerDocument = builder.pagesPerDocument; + this.pageBottomMargin = builder.pageBottomMargin; + this.pageTopMargin = builder.pageTopMargin; + this.pageExtractedTextFormatter = builder.pageExtractedTextFormatter; + this.reversedParagraphPosition = builder.reversedParagraphPosition; + } + /** * Start building a new configuration. * @return The entry point for creating a new configuration. @@ -56,14 +65,6 @@ public static PdfDocumentReaderConfig defaultConfig() { return builder().build(); } - private PdfDocumentReaderConfig(PdfDocumentReaderConfig.Builder builder) { - this.pagesPerDocument = builder.pagesPerDocument; - this.pageBottomMargin = builder.pageBottomMargin; - this.pageTopMargin = builder.pageTopMargin; - this.pageExtractedTextFormatter = builder.pageExtractedTextFormatter; - this.reversedParagraphPosition = builder.reversedParagraphPosition; - } - public static class Builder { private int pagesPerDocument = 1; diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java index 80e35acb399..ea1980ff667 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -180,8 +180,9 @@ private int getNumberOfNewLinesFromPreviousTextPosition(final TextPosition textP double height = textPosition.getHeight(); int numberOfLines = (int) (Math.floor(textYPosition - previousTextYPosition) / height); numberOfLines = Math.max(1, numberOfLines - 1); // exclude current new line - if (DEBUG) + if (DEBUG) { System.out.println(height + " " + numberOfLines); + } return numberOfLines; } else { @@ -191,7 +192,7 @@ private int getNumberOfNewLinesFromPreviousTextPosition(final TextPosition textP private TextLine addNewLine() { TextLine textLine = new TextLine(this.getCurrentPageWidth()); - textLineList.add(textLine); + this.textLineList.add(textLine); return textLine; } @@ -248,7 +249,7 @@ public int getLineLength() { } public String getLine() { - return line; + return this.line; } private int computeIndexForCharacter(final Character character) { @@ -313,7 +314,7 @@ private boolean indexIsInBounds(int index) { private void completeLineWithSpaces() { for (int i = 0; i < this.getLineLength(); ++i) { - line += SPACE_CHARACTER; + this.line += SPACE_CHARACTER; } } @@ -350,8 +351,9 @@ public Character(char characterValue, int index, boolean isCharacterPartOfPrevio this.isFirstCharacterOfAWord = isFirstCharacterOfAWord; this.isCharacterAtTheBeginningOfNewLine = isCharacterAtTheBeginningOfNewLine; this.isCharacterCloseToPreviousWord = isCharacterPartOfASentence; - if (ForkPDFLayoutTextStripper.DEBUG) + if (ForkPDFLayoutTextStripper.DEBUG) { System.out.println(this.toString()); + } } public char getCharacterValue() { @@ -384,14 +386,14 @@ public boolean isCharacterCloseToPreviousWord() { public String toString() { String toString = ""; - toString += index; + toString += this.index; toString += " "; - toString += characterValue; - toString += " isCharacterPartOfPreviousWord=" + isCharacterPartOfPreviousWord; - toString += " isFirstCharacterOfAWord=" + isFirstCharacterOfAWord; - toString += " isCharacterAtTheBeginningOfNewLine=" + isCharacterAtTheBeginningOfNewLine; - toString += " isCharacterPartOfASentence=" + isCharacterCloseToPreviousWord; - toString += " isCharacterCloseToPreviousWord=" + isCharacterCloseToPreviousWord; + toString += this.characterValue; + toString += " isCharacterPartOfPreviousWord=" + this.isCharacterPartOfPreviousWord; + toString += " isFirstCharacterOfAWord=" + this.isFirstCharacterOfAWord; + toString += " isCharacterAtTheBeginningOfNewLine=" + this.isCharacterAtTheBeginningOfNewLine; + toString += " isCharacterPartOfASentence=" + this.isCharacterCloseToPreviousWord; + toString += " isCharacterCloseToPreviousWord=" + this.isCharacterCloseToPreviousWord; return toString; } @@ -424,12 +426,12 @@ public Character createCharacterFromTextPosition(final TextPosition textPosition this.isCharacterCloseToPreviousWord = this.isCharacterCloseToPreviousWord(textPosition); char character = this.getCharacterFromTextPosition(textPosition); int index = (int) textPosition.getX() / ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT; - return new Character(character, index, isCharacterPartOfPreviousWord, isFirstCharacterOfAWord, - isCharacterAtTheBeginningOfNewLine, isCharacterCloseToPreviousWord); + return new Character(character, index, this.isCharacterPartOfPreviousWord, this.isFirstCharacterOfAWord, + this.isCharacterAtTheBeginningOfNewLine, this.isCharacterCloseToPreviousWord); } private boolean isCharacterAtTheBeginningOfNewLine(final TextPosition textPosition) { - if (!firstCharacterOfLineFound) { + if (!this.firstCharacterOfLineFound) { return true; } TextPosition previousTextPosition = this.getPreviousTextPosition(); @@ -438,18 +440,18 @@ private boolean isCharacterAtTheBeginningOfNewLine(final TextPosition textPositi } private boolean isFirstCharacterOfAWord(final TextPosition textPosition) { - if (!firstCharacterOfLineFound) { + if (!this.firstCharacterOfLineFound) { return true; } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(previousTextPosition, textPosition); + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); return (numberOfSpaces > 1) || this.isCharacterAtTheBeginningOfNewLine(textPosition); } private boolean isCharacterCloseToPreviousWord(final TextPosition textPosition) { - if (!firstCharacterOfLineFound) { + if (!this.firstCharacterOfLineFound) { return false; } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(previousTextPosition, textPosition); + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); return (numberOfSpaces > 1 && numberOfSpaces <= ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT); } @@ -485,4 +487,4 @@ private void setPreviousTextPosition(final TextPosition previousTextPosition) { this.previousTextPosition = previousTextPosition; } -} \ No newline at end of file +} diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java index 44bcb511a79..a5d39db89a7 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.layout; import java.awt.geom.Rectangle2D; @@ -70,8 +71,8 @@ public final void setShouldSeparateByBeads(boolean aShouldSeparateByBeads) { * java coordinates (y == 0 is top), not PDF coordinates (y == 0 is bottom). */ public void addRegion(String regionName, Rectangle2D rect) { - regions.add(regionName); - regionArea.put(regionName, rect); + this.regions.add(regionName); + this.regionArea.put(regionName, rect); } /** @@ -80,8 +81,8 @@ public void addRegion(String regionName, Rectangle2D rect) { * @param regionName The name of the region to delete. */ public void removeRegion(String regionName) { - regions.remove(regionName); - regionArea.remove(regionName); + this.regions.remove(regionName); + this.regionArea.remove(regionName); } /** @@ -89,7 +90,7 @@ public void removeRegion(String regionName) { * @return A list of java.lang.String objects to identify the region names. */ public List getRegions() { - return regions; + return this.regions; } /** @@ -98,7 +99,7 @@ public List getRegions() { * @return The text that was identified in that region. */ public String getTextForRegion(String regionName) { - StringWriter text = regionText.get(regionName); + StringWriter text = this.regionText.get(regionName); return text.toString(); } @@ -108,14 +109,14 @@ public String getTextForRegion(String regionName) { * @throws IOException If there is an error while extracting text. */ public void extractRegions(PDPage page) throws IOException { - for (String regionName : regions) { + for (String regionName : this.regions) { setStartPage(getCurrentPageNo()); setEndPage(getCurrentPageNo()); // reset the stored text for the region so this class can be reused. ArrayList> regionCharactersByArticle = new ArrayList>(); regionCharactersByArticle.add(new ArrayList()); - regionCharacterList.put(regionName, regionCharactersByArticle); - regionText.put(regionName, new StringWriter()); + this.regionCharacterList.put(regionName, regionCharactersByArticle); + this.regionText.put(regionName, new StringWriter()); } if (page.hasContents()) { @@ -128,10 +129,10 @@ public void extractRegions(PDPage page) throws IOException { */ @Override protected void processTextPosition(TextPosition text) { - for (Map.Entry regionAreaEntry : regionArea.entrySet()) { + for (Map.Entry regionAreaEntry : this.regionArea.entrySet()) { Rectangle2D rect = regionAreaEntry.getValue(); if (rect.contains(text.getX(), text.getY())) { - charactersByArticle = regionCharacterList.get(regionAreaEntry.getKey()); + this.charactersByArticle = this.regionCharacterList.get(regionAreaEntry.getKey()); super.processTextPosition(text); } } @@ -143,9 +144,9 @@ protected void processTextPosition(TextPosition text) { */ @Override protected void writePage() throws IOException { - for (String region : regionArea.keySet()) { - charactersByArticle = regionCharacterList.get(region); - output = regionText.get(region); + for (String region : this.regionArea.keySet()) { + this.charactersByArticle = this.regionCharacterList.get(region); + this.output = this.regionText.get(region); super.writePage(); } } diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java index f42d7ef3d2a..71c230fafe8 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; import java.util.List; diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java index eec22054d41..5b45f14de8a 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; import org.junit.jupiter.api.Test; diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java index b7e0cd12e44..c409abaa211 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.aot; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; + import org.springframework.aot.hint.RuntimeHints; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.resource; diff --git a/document-readers/tika-reader/pom.xml b/document-readers/tika-reader/pom.xml index 35abb98c6b5..59297edd722 100644 --- a/document-readers/tika-reader/pom.xml +++ b/document-readers/tika-reader/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java b/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java index f004cd197e8..1619e2bc92e 100644 --- a/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java +++ b/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.tika; import java.io.IOException; diff --git a/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java b/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java index 84a167ef1ee..5ae1e7a7d5d 100644 --- a/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java +++ b/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.tika; import org.junit.jupiter.params.ParameterizedTest; diff --git a/models/spring-ai-anthropic/pom.xml b/models/spring-ai-anthropic/pom.xml index beacc8533f0..b2461539485 100644 --- a/models/spring-ai-anthropic/pom.xml +++ b/models/spring-ai-anthropic/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 05ece850a56..d67431a9854 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import java.util.ArrayList; @@ -28,6 +29,9 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; @@ -42,7 +46,11 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -61,9 +69,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * The {@link ChatModel} implementation for the Anthropic service. * @@ -76,16 +81,21 @@ */ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel { - private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class); - - private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue(); public static final Integer DEFAULT_MAX_TOKENS = 500; public static final Double DEFAULT_TEMPERATURE = 0.8; + private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class); + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + /** + * The retry template used to retry the OpenAI API calls. + */ + public final RetryTemplate retryTemplate; + /** * The lower-level API for the Anthropic service. */ @@ -96,11 +106,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM */ private final AnthropicChatOptions defaultOptions; - /** - * The retry template used to retry the OpenAI API calls. - */ - public final RetryTemplate retryTemplate; - /** * Observation registry used for instrumentation. */ diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index c40539f8c91..08ea7de8cf3 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import java.util.ArrayList; @@ -91,91 +92,24 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final AnthropicChatOptions options = new AnthropicChatOptions(); - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withModel(AnthropicApi.ChatModel model) { - this.options.model = model.getValue(); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withMetadata(ChatCompletionRequest.Metadata metadata) { - this.options.metadata = metadata; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.stopSequences = stopSequences; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTopK(Integer topK) { - this.options.topK = topK; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public AnthropicChatOptions build() { - return this.options; - } - + public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { + return builder().withModel(fromOptions.getModel()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withMetadata(fromOptions.getMetadata()) + .withStopSequences(fromOptions.getStopSequences()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -293,19 +227,86 @@ public AnthropicChatOptions copy() { return fromOptions(this); } - public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { - return builder().withModel(fromOptions.getModel()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withMetadata(fromOptions.getMetadata()) - .withStopSequences(fromOptions.getStopSequences()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); + public static class Builder { + + private final AnthropicChatOptions options = new AnthropicChatOptions(); + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withModel(AnthropicApi.ChatModel model) { + this.options.model = model.getValue(); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withMetadata(ChatCompletionRequest.Metadata metadata) { + this.options.metadata = metadata; + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.stopSequences = stopSequences; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.options.topK = topK; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public AnthropicChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java index bf56e842b76..71a47d1e0db 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.aot; import org.springframework.ai.anthropic.api.AnthropicApi; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 5d893ba6eda..35fa4faf6fb 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api; import java.util.ArrayList; @@ -23,6 +24,14 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -38,15 +47,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * @author Christian Tzolov * @author Mariusz Bernacki @@ -57,12 +57,6 @@ public class AnthropicApi { public static final String PROVIDER_NAME = AiProvider.ANTHROPIC.value(); - private static final String HEADER_X_API_KEY = "x-api-key"; - - private static final String HEADER_ANTHROPIC_VERSION = "anthropic-version"; - - private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; - public static final String DEFAULT_BASE_URL = "https://api.anthropic.com"; public static final String DEFAULT_ANTHROPIC_VERSION = "2023-06-01"; @@ -71,10 +65,18 @@ public class AnthropicApi { public static final String BETA_MAX_TOKENS = "max-tokens-3-5-sonnet-2024-07-15"; + private static final String HEADER_X_API_KEY = "x-api-key"; + + private static final String HEADER_ANTHROPIC_VERSION = "anthropic-version"; + + private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; + private final StreamHelper streamHelper = new StreamHelper(); + private WebClient webClient; /** @@ -141,6 +143,74 @@ public AnthropicApi(String baseUrl, String anthropicApiKey, String anthropicVers .build(); } + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletionResponse} as a body and HTTP + * status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/messages") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletionResponse.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + AtomicReference chatCompletionReference = new AtomicReference<>(); + + return this.webClient.post() + .uri("/v1/messages") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) + .filter(event -> event.type() != EventType.PING) + // Detect if the chunk is part of a streaming function call. + .map(event -> { + if (this.streamHelper.isToolUseStart(event)) { + isInsideTool.set(true); + } + return event; + }) + // Group all chunks belonging to the same function call. + .windowUntil(event -> { + if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + .concatMapIterable(window -> { + Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), + this.streamHelper::mergeToolUseEvents); + return List.of(monoChunk); + }) + .flatMap(mono -> mono) + .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) + .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); + } + /** * Check the Models * overview and models for @@ -257,6 +407,14 @@ public ChatCompletionRequest(String model, List messages, Stri this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null); } + public static ChatCompletionRequestBuilder builder() { + return new ChatCompletionRequestBuilder(); + } + + public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { + return new ChatCompletionRequestBuilder(request); + } + /** * @param userId An external identifier for the user who is associated with the * request. This should be a uuid, hash value, or other opaque identifier. @@ -265,15 +423,9 @@ public ChatCompletionRequest(String model, List messages, Stri */ @JsonInclude(Include.NON_NULL) public record Metadata(@JsonProperty("user_id") String userId) { - } - public static ChatCompletionRequestBuilder builder() { - return new ChatCompletionRequestBuilder(); } - public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { - return new ChatCompletionRequestBuilder(request); - } } public static class ChatCompletionRequestBuilder { @@ -378,12 +530,16 @@ public ChatCompletionRequestBuilder withTools(List tools) { } public ChatCompletionRequest build() { - return new ChatCompletionRequest(model, messages, system, maxTokens, metadata, stopSequences, stream, - temperature, topP, topK, tools); + return new ChatCompletionRequest(this.model, this.messages, this.system, this.maxTokens, this.metadata, + this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools); } } + /////////////////////////////////////// + /// ERROR EVENT + /////////////////////////////////////// + /** * Input messages. * @@ -535,9 +691,15 @@ public record Source( // @formatter:off public Source(String mediaType, String data) { this("base64", mediaType, data); } + } + } + /////////////////////////////////////// + /// CONTENT_BLOCK EVENTS + /////////////////////////////////////// + @JsonInclude(Include.NON_NULL) public record Tool(// @formatter:off @JsonProperty("name") String name, @@ -546,6 +708,8 @@ public record Tool(// @formatter:off // @formatter:on } + // CB START EVENT + /** * @param id Unique object identifier. The format and length of IDs may change over * time. @@ -572,6 +736,8 @@ public record ChatCompletionResponse( // @formatter:off // @formatter:on } + // CB DELTA EVENT + /** * Usage statistics. * @@ -585,94 +751,7 @@ public record Usage( // @formatter:off // @formatter:off } - - /////////////////////////////////////// - /// ERROR EVENT - /////////////////////////////////////// - - /** - * The evnt type of the streamed chunk. - */ - public enum EventType { - - /** - * Message start event. Contains a Message object with empty content. - */ - @JsonProperty("message_start") - MESSAGE_START, - - /** - * Message delta event, indicating top-level changes to the final Message object. - */ - @JsonProperty("message_delta") - MESSAGE_DELTA, - - /** - * A final message stop event. - */ - @JsonProperty("message_stop") - MESSAGE_STOP, - - /** - * - */ - @JsonProperty("content_block_start") - CONTENT_BLOCK_START, - - /** - * - */ - @JsonProperty("content_block_delta") - CONTENT_BLOCK_DELTA, - - /** - * - */ - @JsonProperty("content_block_stop") - CONTENT_BLOCK_STOP, - - /** - * - */ - @JsonProperty("error") - ERROR, - - /** - * - */ - @JsonProperty("ping") - PING, - - /** - * Artifically created event to aggregate tool use events. - */ - TOOL_USE_AGGREATE; - - } - - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", - visible = true) - @JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockStartEvent.class, name = "content_block_start"), - @JsonSubTypes.Type(value = ContentBlockDeltaEvent.class, name = "content_block_delta"), - @JsonSubTypes.Type(value = ContentBlockStopEvent.class, name = "content_block_stop"), - - @JsonSubTypes.Type(value = PingEvent.class, name = "ping"), - - @JsonSubTypes.Type(value = ErrorEvent.class, name = "error"), - - @JsonSubTypes.Type(value = MessageStartEvent.class, name = "message_start"), - @JsonSubTypes.Type(value = MessageDeltaEvent.class, name = "message_delta"), - @JsonSubTypes.Type(value = MessageStopEvent.class, name = "message_stop") }) - public interface StreamEvent { - - @JsonProperty("type") - EventType type(); - - } - - /////////////////////////////////////// - /// CONTENT_BLOCK EVENTS - /////////////////////////////////////// + /// ECB STOP /** * Special event used to aggregate multiple tool use events into a single event with @@ -736,13 +815,17 @@ void squashIntoContentBlock() { @Override public String toString() { - return "EventToolUseBuilder [index=" + index + ", id=" + id + ", name=" + name + ", partialJson=" - + partialJson + ", toolUseMap=" + toolContentBlocks + "]"; + return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson=" + + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; } } - // CB START EVENT + /////////////////////////////////////// + /// MESSAGE EVENTS + /////////////////////////////////////// + + // MESSAGE START EVENT @JsonInclude(Include.NON_NULL) public record ContentBlockStartEvent(// @formatter:off @@ -773,7 +856,7 @@ public record ContentBlockText( } }// @formatter:on - // CB DELTA EVENT + // MESSAGE DELTA EVENT @JsonInclude(Include.NON_NULL) public record ContentBlockDeltaEvent(// @formatter:off @@ -803,7 +886,7 @@ public record ContentBlockDeltaJson( } }// @formatter:on - /// ECB STOP + // MESSAGE STOP EVENT @JsonInclude(Include.NON_NULL) public record ContentBlockStopEvent(// @formatter:off @@ -811,20 +894,12 @@ public record ContentBlockStopEvent(// @formatter:off @JsonProperty("index") Integer index) implements StreamEvent { }// @formatter:on - /////////////////////////////////////// - /// MESSAGE EVENTS - /////////////////////////////////////// - - // MESSAGE START EVENT - @JsonInclude(Include.NON_NULL) public record MessageStartEvent(// @formatter:off @JsonProperty("type") EventType type, @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { }// @formatter:on - // MESSAGE DELTA EVENT - @JsonInclude(Include.NON_NULL) public record MessageDeltaEvent(// @formatter:off @JsonProperty("type") EventType type, @@ -843,8 +918,6 @@ public record MessageDeltaUsage( } }// @formatter:on - // MESSAGE STOP EVENT - @JsonInclude(Include.NON_NULL) public record MessageStopEvent(// @formatter:off @JsonProperty("type") EventType type) implements StreamEvent { @@ -873,74 +946,4 @@ public record PingEvent(// @formatter:off @JsonProperty("type") EventType type) implements StreamEvent { }// @formatter:on - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletionResponse} as a body and HTTP - * status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/messages") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletionResponse.class); - } - - private final StreamHelper streamHelper = new StreamHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - AtomicReference chatCompletionReference = new AtomicReference<>(); - - return this.webClient.post() - .uri("/v1/messages") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) - .filter(event -> event.type() != EventType.PING) - // Detect if the chunk is part of a streaming function call. - .map(event -> { - if (this.streamHelper.isToolUseStart(event)) { - isInsideTool.set(true); - } - return event; - }) - // Group all chunks belonging to the same function call. - .windowUntil(event -> { - if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - .concatMapIterable(window -> { - Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), - this.streamHelper::mergeToolUseEvents); - return List.of(monoChunk); - }) - .flatMap(mono -> mono) - .map(event -> streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) - .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); - } - } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index 054bf023bb7..677bdb2e49a 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api; import java.util.ArrayList; @@ -22,22 +23,22 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; -import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.anthropic.api.AnthropicApi.Usage; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; -import org.springframework.ai.anthropic.api.AnthropicApi.MessageDeltaEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.MessageStartEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaJson; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaText; +import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockText; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockToolUse; import org.springframework.ai.anthropic.api.AnthropicApi.EventType; +import org.springframework.ai.anthropic.api.AnthropicApi.MessageDeltaEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.MessageStartEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.Usage; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * Helper class to support streaming function calling. diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java index 0ed5cdde1fe..83edc7e15d1 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.metadata; import java.time.Duration; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java index 1de5edc6aa7..fbafc2297d3 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.metadata; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -27,10 +28,6 @@ */ public class AnthropicUsage implements Usage { - public static AnthropicUsage from(AnthropicApi.Usage usage) { - return new AnthropicUsage(usage); - } - private final AnthropicApi.Usage usage; protected AnthropicUsage(AnthropicApi.Usage usage) { @@ -38,6 +35,10 @@ protected AnthropicUsage(AnthropicApi.Usage usage) { this.usage = usage; } + public static AnthropicUsage from(AnthropicApi.Usage usage) { + return new AnthropicUsage(usage); + } + protected AnthropicApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 1a816fc1850..144a9453788 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic; import java.io.IOException; import java.util.ArrayList; @@ -30,11 +29,12 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.tool.MockWeatherService; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; @@ -47,6 +47,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -59,7 +60,7 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = AnthropicChatModelIT.Config.class, properties = "spring.ai.retry.on-http-codes=429") @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") @@ -76,17 +77,25 @@ class AnthropicChatModelIT { @Value("classpath:/prompts/system-message.st") private Resource systemResource; + private static void validateChatResponseMetadata(ChatResponse response, String model) { + assertThat(response.getMetadata().getId()).isNotEmpty(); + assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20241022" }) void roleTest(String modelName) { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), AnthropicChatOptions.builder().withModel(modelName).build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); @@ -103,17 +112,17 @@ void roleTest(String modelName) { void testMessageHistory() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), AnthropicChatOptions.builder().withModel("claude-3-sonnet-20240229").build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -167,16 +176,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = mapOutputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -189,7 +195,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = beanOutputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -210,7 +216,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -234,7 +240,7 @@ void multiModalityTest() throws IOException { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("banan", "apple", "basket"); @@ -257,7 +263,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -284,7 +290,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -301,7 +307,7 @@ void streamFunctionCallTest() { void validateCallResponseMetadata() { String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -316,7 +322,7 @@ void validateCallResponseMetadata() { void validateStreamCallResponseMetadata() { String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .stream() @@ -328,12 +334,8 @@ void validateStreamCallResponseMetadata() { validateChatResponseMetadata(response, model); } - private static void validateChatResponseMetadata(ChatResponse response, String model) { - assertThat(response.getMetadata().getId()).isNotEmpty(); - assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); - assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + record ActorsFilmsRecord(String actor, List movies) { + } @SpringBootConfiguration @@ -360,4 +362,4 @@ public AnthropicChatModel openAiChatModel(AnthropicApi api) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java index 17ef19c909e..968e0c6a113 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic; import java.util.List; import java.util.stream.Collectors; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -39,9 +42,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link AnthropicChatModel}. @@ -61,7 +62,7 @@ public class AnthropicChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -77,7 +78,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -99,7 +100,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -121,7 +122,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata, String finishReasons) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java index e92f4d67041..e90a94f8764 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import org.springframework.ai.anthropic.api.AnthropicApi; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java index be251b00dfd..3dec1f7bcd0 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import org.junit.jupiter.api.Test; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java index d57bf765b1c..9cd11068bfa 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java @@ -1,34 +1,35 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.anthropic; + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic; import java.io.IOException; import java.nio.charset.Charset; import java.util.List; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; import org.springframework.core.io.DefaultResourceLoader; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -44,6 +45,7 @@ public void readEvents() throws IOException { .getContentAsString(Charset.defaultCharset()); List events = new ObjectMapper().readerFor(new TypeReference>() { + }).readValue(json); logger.info(events.toString()); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java index 11683d844b8..f38a6c8e671 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index ceaebdfe6d2..d830980771b 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic.api; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic.api; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; @@ -28,7 +29,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -43,7 +44,7 @@ void chatCompletionEntity() { AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")), Role.USER); - ResponseEntity response = anthropicApi + ResponseEntity response = this.anthropicApi .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, false)); @@ -58,7 +59,7 @@ void chatCompletionStream() { AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")), Role.USER); - Flux response = anthropicApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.anthropicApi.chatCompletionStream(new ChatCompletionRequest( AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, true)); assertThat(response).isNotNull(); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java index 6e9440e0af2..0be31a1386d 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.List; @@ -25,10 +26,10 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; -import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.ai.anthropic.api.tool.XmlHelper.FunctionCalls; import org.springframework.ai.anthropic.api.tool.XmlHelper.Tools; @@ -60,10 +61,6 @@ @SuppressWarnings("null") public class AnthropicApiLegacyToolIT { - private static final Logger logger = LoggerFactory.getLogger(AnthropicApiLegacyToolIT.class); - - AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); - public static final String TOO_SYSTEM_PROMPT_TEMPLATE = """ In this environment you have access to a set of tools you can use to answer the user's question. @@ -84,9 +81,9 @@ public class AnthropicApiLegacyToolIT { public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); - static { - FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); - } + private static final Logger logger = LoggerFactory.getLogger(AnthropicApiLegacyToolIT.class); + + AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); @Test void toolCalls() { @@ -120,7 +117,7 @@ void toolCalls() { private ResponseEntity doCall(ChatCompletionRequest chatCompletionRequest) { - ResponseEntity response = anthropicApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity response = this.anthropicApi.chatCompletionEntity(chatCompletionRequest); FunctionCalls functionCalls = XmlHelper.extractFunctionCalls(response.getBody().content().get(0).text()); @@ -150,4 +147,8 @@ private ResponseEntity doCall(ChatCompletionRequest chat List.of(chatCompletionMessage2), null, 500, 0.8, false)); } + static { + FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); + } + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java index c5642976267..767d73d0d68 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.ArrayList; @@ -26,11 +27,11 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; -import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.ai.anthropic.api.AnthropicApi.Tool; import org.springframework.ai.model.ModelOptionsUtils; @@ -53,16 +54,12 @@ @SuppressWarnings("null") public class AnthropicApiToolIT { + public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(AnthropicApiLegacyToolIT.class); AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); - public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); - - static { - FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); - } - List tools = List.of(new Tool("getCurrentWeather", "Get the weather in location. Return temperature in 30°F or 30°C format.", ModelOptionsUtils.jsonToMap(""" { @@ -109,10 +106,10 @@ private ResponseEntity doCall(List mes .withMessages(messageConversation) .withMaxTokens(1500) .withTemperature(0.8) - .withTools(tools) + .withTools(this.tools) .build(); - ResponseEntity response = anthropicApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity response = this.anthropicApi.chatCompletionEntity(chatCompletionRequest); List toolToUseList = response.getBody() .content() @@ -155,4 +152,8 @@ private ResponseEntity doCall(List mes return doCall(messageConversation); } + static { + FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); + } + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java index 762f60fd256..8af45829870 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.function.Function; @@ -28,14 +29,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, Unit.C); } /** @@ -64,26 +72,21 @@ private Unit(String text) { } /** - * Weather Function response. + * Weather Function request. */ - public record Response(double temp, Unit unit) { - } + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { - @Override - public Response apply(Request request) { + } - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } + /** + * Weather Function response. + */ + public record Response(double temp, Unit unit) { - return new Response(temperature, Unit.C); } -} \ No newline at end of file +} diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java index d2b5c7a7eaa..9ea40d1c800 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.List; @@ -45,45 +46,6 @@ public class XmlHelper { private static final XmlMapper xmlMapper = new XmlMapper(); - @JsonInclude(Include.NON_NULL) // @formatter:off - @JacksonXmlRootElement(localName = "tools") - public record Tools( - @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("tool_description") List toolDescriptions) { - - public record ToolDescription( - @JsonProperty("tool_name") String toolName, - @JsonProperty("description") String description, - @JacksonXmlElementWrapper(localName = "parameters") @JsonProperty("parameter") List parameters) { - - @JacksonXmlRootElement(localName = "parameter") - public record Parameter( - @JsonProperty("name") String name, - @JsonProperty("type") String type, - @JsonProperty("description") String description) { - } - } - } // @formatter:on - - @JsonInclude(Include.NON_NULL) // @formatter:off - @JacksonXmlRootElement(localName = "function_calls") - public record FunctionCalls(@JsonProperty("invoke") Invoke invoke) { - public record Invoke( - @JsonProperty("tool_name") String toolName, - @JsonProperty("parameters") Map parameters) { - } - } // @formatter:on - - @JsonInclude(Include.NON_NULL) // @formatter:off - @JacksonXmlRootElement(localName = "function_results") - public record FunctionResults( - @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("result") List result) { - - public record Result( - @JsonProperty("tool_name") String toolName, - @JsonProperty("stdout") Object stdout) { - } - } // @formatter:on - public static String extractFunctionCallsXmlBlock(String text) { if (!StringUtils.hasText(text)) { return ""; @@ -149,4 +111,43 @@ public static void main(String[] args) throws JsonMappingException, JsonProcessi } + @JsonInclude(Include.NON_NULL) // @formatter:off + @JacksonXmlRootElement(localName = "tools") + public record Tools( + @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("tool_description") List toolDescriptions) { + + public record ToolDescription( + @JsonProperty("tool_name") String toolName, + @JsonProperty("description") String description, + @JacksonXmlElementWrapper(localName = "parameters") @JsonProperty("parameter") List parameters) { + + @JacksonXmlRootElement(localName = "parameter") + public record Parameter( + @JsonProperty("name") String name, + @JsonProperty("type") String type, + @JsonProperty("description") String description) { + } + } + } // @formatter:on + + @JsonInclude(Include.NON_NULL) // @formatter:off + @JacksonXmlRootElement(localName = "function_calls") + public record FunctionCalls(@JsonProperty("invoke") Invoke invoke) { + public record Invoke( + @JsonProperty("tool_name") String toolName, + @JsonProperty("parameters") Map parameters) { + } + } // @formatter:on + + @JsonInclude(Include.NON_NULL) // @formatter:off + @JacksonXmlRootElement(localName = "function_results") + public record FunctionResults( + @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("result") List result) { + + public record Result( + @JsonProperty("tool_name") String toolName, + @JsonProperty("stdout") Object stdout) { + } + } // @formatter:on + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index 7fd9a6e6c27..d93a84e004c 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic.client; import java.io.IOException; import java.net.URL; @@ -31,6 +30,8 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.AnthropicTestConfiguration; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -51,7 +52,7 @@ import org.springframework.test.context.ActiveProfiles; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429") @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") @@ -66,16 +67,13 @@ class AnthropicChatClientIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test void call() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .advisors(new SimpleLoggerAdvisor()) - .system(s -> s.text(systemTextResource) + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -91,7 +89,7 @@ void call() { @Test void listOutputConverterString() { // @formatter:off - List collection = ChatClient.create(chatModel).prompt() + List collection = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -106,7 +104,7 @@ void listOutputConverterString() { void listOutputConverterBean() { // @formatter:off - List actorsFilms = ChatClient.create(chatModel).prompt() + List actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") .call() .entity(new ParameterizedTypeReference>() { @@ -123,7 +121,7 @@ void customOutputConverter() { var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); // @formatter:off - List flavors = ChatClient.create(chatModel).prompt() + List flavors = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -138,7 +136,7 @@ void customOutputConverter() { @Test void mapOutputConverter() { // @formatter:off - Map result = ChatClient.create(chatModel).prompt() + Map result = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("Provide me a List of {subject}") .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) .call() @@ -153,7 +151,7 @@ void mapOutputConverter() { void beanOutputConverter() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography for a random actor.") .call() .entity(ActorsFilms.class); @@ -167,7 +165,7 @@ void beanOutputConverter() { void beanOutputConverterRecords() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks.") .call() .entity(ActorsFilms.class); @@ -184,7 +182,7 @@ void beanStreamOutputConverterRecords() { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = ChatClient.create(chatModel) + Flux chatResponse = ChatClient.create(this.chatModel) .prompt() .advisors(new SimpleLoggerAdvisor()) .user(u -> u @@ -211,7 +209,7 @@ void beanStreamOutputConverterRecords() { void functionCallTest() { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .call() @@ -227,7 +225,7 @@ void functionCallTest() { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) .build() @@ -245,7 +243,7 @@ void defaultFunctionCallTest() { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .stream() @@ -264,7 +262,7 @@ void streamFunctionCallTest() { void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) @@ -287,7 +285,7 @@ void multiModalityImageUrl(String modelName) throws IOException { URL url = new URL("https://docs.spring.io/spring-ai/reference/1.0.0-SNAPSHOT/_images/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to .options(AnthropicChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) @@ -304,7 +302,7 @@ void multiModalityImageUrl(String modelName) throws IOException { void streamingMultiModality() throws IOException { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET) .build()) .user(u -> u.text("Explain what do you see on this picture?") @@ -320,4 +318,8 @@ void streamingMultiModality() throws IOException { assertThat(content).containsAnyOf("bowl", "basket"); } -} \ No newline at end of file + record ActorsFilms(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties b/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties index 8e8b3b2c3c6..4466a718052 100644 --- a/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties +++ b/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties @@ -1 +1,17 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + logging.level.org.springframework.ai.chat.client.advisor=DEBUG diff --git a/models/spring-ai-azure-openai/pom.xml b/models/spring-ai-azure-openai/pom.xml index 63443562899..101b9e508d7 100644 --- a/models/spring-ai-azure-openai/pom.xml +++ b/models/spring-ai-azure-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java index 1d1e4afd941..314925b3a54 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.io.IOException; +import java.util.List; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.AudioTranscriptionFormat; import com.azure.ai.openai.models.AudioTranscriptionOptions; import com.azure.ai.openai.models.AudioTranscriptionTimestampGranularity; import com.azure.core.http.rest.Response; + import org.springframework.ai.audio.transcription.AudioTranscription; import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; @@ -35,9 +40,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.io.IOException; -import java.util.List; - /** * AzureOpenAI audio transcription client implementation for backed by * {@link OpenAIClient}. You provide as input the audio file you want to transcribe and @@ -61,6 +63,15 @@ public AzureOpenAiAudioTranscriptionModel(OpenAIClient openAIClient, AzureOpenAi this.defaultOptions = options; } + private static byte[] toBytes(Resource resource) { + try { + return resource.getInputStream().readAllBytes(); + } + catch (IOException e) { + throw new IllegalArgumentException("Failed to read resource: " + resource, e); + } + } + public String call(Resource audioResource) { AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioResource); return call(transcriptionRequest).getResult().getOutput(); @@ -73,7 +84,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt audioTranscripti AudioTranscriptionFormat responseFormat = audioTranscriptionOptions.getResponseFormat(); if (JSON_FORMATS.contains(responseFormat)) { - var audioTranscription = openAIClient.getAudioTranscription(deploymentOrModelName, FILENAME_MARKER, + var audioTranscription = this.openAIClient.getAudioTranscription(deploymentOrModelName, FILENAME_MARKER, audioTranscriptionOptions); List words = null; @@ -108,7 +119,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt audioTranscripti return new AudioTranscriptionResponse(transcript, metadata); } else { - Response audioTranscription = openAIClient.getAudioTranscriptionTextWithResponse( + Response audioTranscription = this.openAIClient.getAudioTranscriptionTextWithResponse( deploymentOrModelName, FILENAME_MARKER, audioTranscriptionOptions, null); String text = audioTranscription.getValue(); AudioTranscription transcript = new AudioTranscription(text); @@ -119,7 +130,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt audioTranscripti private String getDeploymentName(AudioTranscriptionPrompt audioTranscriptionPrompt) { var runtimeOptions = audioTranscriptionPrompt.getOptions(); - if (defaultOptions != null) { + if (this.defaultOptions != null) { runtimeOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, AzureOpenAiAudioTranscriptionOptions.class); } @@ -189,13 +200,4 @@ private AudioTranscriptionOptions toAudioTranscriptionOptions(AudioTranscription return audioTranscriptionOptions; } - private static byte[] toBytes(Resource resource) { - try { - return resource.getInputStream().readAllBytes(); - } - catch (IOException e) { - throw new IllegalArgumentException("Failed to read resource: " + resource, e); - } - } - } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java index bd80aace91a..b79e2588518 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.util.List; + import com.azure.ai.openai.models.AudioTranscriptionFormat; import com.azure.ai.openai.models.AudioTranscriptionTimestampGranularity; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.audio.transcription.AudioTranscriptionOptions; import org.springframework.util.Assert; -import java.util.List; - /** * @author Piotr Olaszewski */ @@ -66,62 +68,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected AzureOpenAiAudioTranscriptionOptions options; - - public Builder() { - this.options = new AzureOpenAiAudioTranscriptionOptions(); - } - - public Builder(AzureOpenAiAudioTranscriptionOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withDeploymentName(String deploymentName) { - this.options.setDeploymentName(deploymentName); - return this; - } - - public Builder withLanguage(String language) { - this.options.language = language; - return this; - } - - public Builder withPrompt(String prompt) { - this.options.prompt = prompt; - return this; - } - - public Builder withResponseFormat(TranscriptResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withTemperature(Float temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withGranularityType(List granularityType) { - this.options.granularityType = granularityType; - return this; - } - - public AzureOpenAiAudioTranscriptionOptions build() { - Assert.hasText(options.model, "model must not be empty"); - Assert.notNull(options.responseFormat, "response_format must not be null"); - - return this.options; - } - - } - @Override public String getModel() { return this.model; @@ -132,7 +78,7 @@ public void setModel(String model) { } public String getDeploymentName() { - return deploymentName; + return this.deploymentName; } public void setDeploymentName(String deploymentName) { @@ -163,7 +109,6 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } - public TranscriptResponseFormat getResponseFormat() { return this.responseFormat; } @@ -184,10 +129,10 @@ public void setGranularityType(List granularityType) { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((prompt == null) ? 0 : prompt.hashCode()); - result = prime * result + ((language == null) ? 0 : language.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.prompt == null) ? 0 : this.prompt.hashCode()); + result = prime * result + ((this.language == null) ? 0 : this.language.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); return result; } @@ -204,7 +149,7 @@ public boolean equals(Object obj) { if (other.model != null) return false; } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) return false; if (this.prompt == null) { if (other.prompt != null) @@ -237,7 +182,109 @@ public enum WhisperModel { } public String getValue() { - return value; + return this.value; + } + + } + + public enum TranscriptResponseFormat { + + // @formatter:off + @JsonProperty("json") JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class), + @JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT, String.class), + @JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT, String.class), + @JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class), + @JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT, String.class); + + public final AudioTranscriptionFormat value; + + public final Class responseType; + + TranscriptResponseFormat(AudioTranscriptionFormat value, Class responseType) { + this.value = value; + this.responseType = responseType; + } + + public AudioTranscriptionFormat getValue() { + return this.value; + } + + public Class getResponseType() { + return this.responseType; + } + } + + public enum GranularityType { + + // @formatter:off + @JsonProperty("word") WORD(AudioTranscriptionTimestampGranularity.WORD), + @JsonProperty("segment") SEGMENT(AudioTranscriptionTimestampGranularity.SEGMENT); + // @formatter:on + + public final AudioTranscriptionTimestampGranularity value; + + GranularityType(AudioTranscriptionTimestampGranularity value) { + this.value = value; + } + + public AudioTranscriptionTimestampGranularity getValue() { + return this.value; + } + + } + + public static class Builder { + + protected AzureOpenAiAudioTranscriptionOptions options; + + public Builder() { + this.options = new AzureOpenAiAudioTranscriptionOptions(); + } + + public Builder(AzureOpenAiAudioTranscriptionOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withDeploymentName(String deploymentName) { + this.options.setDeploymentName(deploymentName); + return this; + } + + public Builder withLanguage(String language) { + this.options.language = language; + return this; + } + + public Builder withPrompt(String prompt) { + this.options.prompt = prompt; + return this; + } + + public Builder withResponseFormat(TranscriptResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withGranularityType(List granularityType) { + this.options.granularityType = granularityType; + return this; + } + + public AzureOpenAiAudioTranscriptionOptions build() { + Assert.hasText(this.options.model, "model must not be empty"); + Assert.notNull(this.options.responseFormat, "response_format must not be null"); + + return this.options; } } @@ -308,51 +355,6 @@ public record Segment( @JsonProperty("no_speech_prob") Float noSpeechProb) { // @formatter:on } - } - - public enum TranscriptResponseFormat { - - // @formatter:off - @JsonProperty("json") JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class), - @JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT, String.class), - @JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT, String.class), - @JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class), - @JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT, String.class); - - public final AudioTranscriptionFormat value; - - public final Class responseType; - - TranscriptResponseFormat(AudioTranscriptionFormat value, Class responseType) { - this.value = value; - this.responseType = responseType; - } - - public AudioTranscriptionFormat getValue() { - return this.value; - } - - public Class getResponseType() { - return this.responseType; - } - } - - public enum GranularityType { - - // @formatter:off - @JsonProperty("word") WORD(AudioTranscriptionTimestampGranularity.WORD), - @JsonProperty("segment") SEGMENT(AudioTranscriptionTimestampGranularity.SEGMENT); - // @formatter:on - - public final AudioTranscriptionTimestampGranularity value; - - GranularityType(AudioTranscriptionTimestampGranularity value) { - this.value = value; - } - - public AudioTranscriptionTimestampGranularity getValue() { - return this.value; - } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 2436c7d3d65..77c2ea0b245 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,49 @@ package org.springframework.ai.azure.openai; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.models.*; +import com.azure.ai.openai.models.ChatChoice; +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; +import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; +import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatCompletionsResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsToolCall; +import com.azure.ai.openai.models.ChatCompletionsToolDefinition; +import com.azure.ai.openai.models.ChatMessageContentItem; +import com.azure.ai.openai.models.ChatMessageImageContentItem; +import com.azure.ai.openai.models.ChatMessageImageUrl; +import com.azure.ai.openai.models.ChatMessageTextContentItem; +import com.azure.ai.openai.models.ChatRequestAssistantMessage; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestSystemMessage; +import com.azure.ai.openai.models.ChatRequestToolMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; +import com.azure.ai.openai.models.CompletionsFinishReason; +import com.azure.ai.openai.models.ContentFilterResultsForPrompt; +import com.azure.ai.openai.models.FunctionCall; +import com.azure.ai.openai.models.FunctionDefinition; import com.azure.core.util.BinaryData; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -54,20 +90,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.Base64; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; - /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. @@ -153,6 +175,19 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi this.observationRegistry = observationRegistry; } + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { + Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); + String id = chatCompletions.getId(); + Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); + return ChatResponseMetadata.builder() + .withId(id) + .withUsage(usage) + .withModel(chatCompletions.getModel()) + .withPromptMetadata(promptFilterMetadata) + .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) + .build(); + } + public AzureOpenAiChatOptions getDefaultOptions() { return AzureOpenAiChatOptions.fromOptions(this.defaultOptions); } @@ -302,19 +337,6 @@ private Generation buildGeneration(ChatChoice choice, Map metada return new Generation(assistantMessage, generationMetadata); } - public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { - Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); - String id = chatCompletions.getId(); - Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); - return ChatResponseMetadata.builder() - .withId(id) - .withUsage(usage) - .withModel(chatCompletions.getModel()) - .withPromptMetadata(promptFilterMetadata) - .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) - .build(); - } - /** * Test access. */ @@ -332,8 +354,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { options = this.merge(options, this.defaultOptions); - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) + if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); + } if (prompt.getOptions() != null) { AzureOpenAiChatOptions updatedRuntimeOptions; @@ -428,14 +451,16 @@ private List fromSpringAiMessage(Message message) { private String getMediaUrl(Media media) { Object data = media.getData(); - if (data instanceof String dataUrl) + if (data instanceof String dataUrl) { return dataUrl; + } else if (data instanceof byte[] dataBytes) { String base64EncodedData = Base64.getEncoder().encodeToString(dataBytes); return "data:" + media.getMimeType() + ";base64," + base64EncodedData; } - else + else { throw new IllegalArgumentException("Unknown media data type " + data.getClass().getName()); + } } private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 5685fa43ee7..f890f1266ab 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,18 +22,18 @@ import java.util.Map; import java.util.Set; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.boot.context.properties.NestedConfigurationProperty; -import org.springframework.util.Assert; - import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + /** * The configuration information for a chat completions request. Completions support a * wide variety of tasks and generate text that continues from or "completes" provided @@ -206,129 +206,26 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected AzureOpenAiChatOptions options; - - public Builder() { - this.options = new AzureOpenAiChatOptions(); - } - - public Builder(AzureOpenAiChatOptions options) { - this.options = options; - } - - public Builder withDeploymentName(String deploymentName) { - this.options.deploymentName = deploymentName; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withLogitBias(Map logitBias) { - this.options.logitBias = logitBias; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withN(Integer n) { - this.options.n = n; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withUser(String user) { - this.options.user = user; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withSeed(Long seed) { - this.options.seed = seed; - return this; - } - - public Builder withLogprobs(Boolean logprobs) { - this.options.logprobs = logprobs; - return this; - } - - public Builder withTopLogprobs(Integer topLogprobs) { - this.options.topLogProbs = topLogprobs; - return this; - } - - public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { - this.options.enhancements = enhancements; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public AzureOpenAiChatOptions build() { - return this.options; - } - + public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) { + return builder().withDeploymentName(fromOptions.getDeploymentName()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty() : null) + .withLogitBias(fromOptions.getLogitBias()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withN(fromOptions.getN()) + .withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withUser(fromOptions.getUser()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withLogprobs(fromOptions.isLogprobs()) + .withTopLogprobs(fromOptions.getTopLogProbs()) + .withEnhancements(fromOptions.getEnhancements()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override @@ -526,26 +423,129 @@ public AzureOpenAiChatOptions copy() { return fromOptions(this); } - public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) { - return builder().withDeploymentName(fromOptions.getDeploymentName()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty() : null) - .withLogitBias(fromOptions.getLogitBias()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withN(fromOptions.getN()) - .withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withUser(fromOptions.getUser()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withSeed(fromOptions.getSeed()) - .withLogprobs(fromOptions.isLogprobs()) - .withTopLogprobs(fromOptions.getTopLogProbs()) - .withEnhancements(fromOptions.getEnhancements()) - .withToolContext(fromOptions.getToolContext()) - .build(); + public static class Builder { + + protected AzureOpenAiChatOptions options; + + public Builder() { + this.options = new AzureOpenAiChatOptions(); + } + + public Builder(AzureOpenAiChatOptions options) { + this.options = options; + } + + public Builder withDeploymentName(String deploymentName) { + this.options.deploymentName = deploymentName; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withLogitBias(Map logitBias) { + this.options.logitBias = logitBias; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withSeed(Long seed) { + this.options.seed = seed; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + this.options.topLogProbs = topLogprobs; + return this; + } + + public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { + this.options.enhancements = enhancements; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public AzureOpenAiChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java index 17827585124..c7ca01b226b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.util.ArrayList; +import java.util.List; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.EmbeddingItem; import com.azure.ai.openai.models.Embeddings; import com.azure.ai.openai.models.EmbeddingsOptions; - import io.micrometer.observation.ObservationRegistry; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; @@ -41,9 +44,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.List; - /** * Azure Open AI Embedding Model implementation. * @@ -56,14 +56,14 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class); + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private final OpenAIClient azureOpenAiClient; private final AzureOpenAiEmbeddingOptions defaultOptions; private final MetadataMode metadataMode; - private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); - /** * Observation registry used for instrumentation. */ diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java index 7713f95f633..e2e8f3e2494 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.List; import com.fasterxml.jackson.annotation.JsonIgnore; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -58,6 +60,61 @@ public static Builder builder() { return new Builder(); } + @Override + @JsonIgnore + public String getModel() { + return getDeploymentName(); + } + + @JsonIgnore + public void setModel(String model) { + setDeploymentName(model); + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + public String getDeploymentName() { + return this.deploymentName; + } + + public void setDeploymentName(String deploymentName) { + this.deploymentName = deploymentName; + } + + public String getInputType() { + return this.inputType; + } + + public void setInputType(String inputType) { + this.inputType = inputType; + } + + @Override + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public com.azure.ai.openai.models.EmbeddingsOptions toAzureOptions(List instructions) { + + var azureOptions = new com.azure.ai.openai.models.EmbeddingsOptions(instructions); + azureOptions.setModel(this.getDeploymentName()); + azureOptions.setUser(this.getUser()); + azureOptions.setInputType(this.getInputType()); + azureOptions.setDimensions(this.getDimensions()); + + return azureOptions; + } + public static class Builder { private final AzureOpenAiEmbeddingOptions options = new AzureOpenAiEmbeddingOptions(); @@ -125,59 +182,4 @@ public AzureOpenAiEmbeddingOptions build() { } - @Override - @JsonIgnore - public String getModel() { - return getDeploymentName(); - } - - @JsonIgnore - public void setModel(String model) { - setDeploymentName(model); - } - - public String getUser() { - return this.user; - } - - public void setUser(String user) { - this.user = user; - } - - public String getDeploymentName() { - return this.deploymentName; - } - - public void setDeploymentName(String deploymentName) { - this.deploymentName = deploymentName; - } - - public String getInputType() { - return this.inputType; - } - - public void setInputType(String inputType) { - this.inputType = inputType; - } - - @Override - public Integer getDimensions() { - return this.dimensions; - } - - public void setDimensions(Integer dimensions) { - this.dimensions = dimensions; - } - - public com.azure.ai.openai.models.EmbeddingsOptions toAzureOptions(List instructions) { - - var azureOptions = new com.azure.ai.openai.models.EmbeddingsOptions(instructions); - azureOptions.setModel(this.getDeploymentName()); - azureOptions.setUser(this.getUser()); - azureOptions.setInputType(this.getInputType()); - azureOptions.setDimensions(this.getDimensions()); - - return azureOptions; - } - } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index 7d2b3dae380..9b1b466efac 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -1,5 +1,23 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai; +import java.util.List; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ImageGenerationOptions; import com.azure.ai.openai.models.ImageGenerationQuality; @@ -13,6 +31,7 @@ import com.fasterxml.jackson.databind.json.JsonMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiImageGenerationMetadata; import org.springframework.ai.azure.openai.metadata.AzureOpenAiImageResponseMetadata; import org.springframework.ai.image.Image; @@ -25,8 +44,6 @@ import org.springframework.ai.util.JacksonUtils; import org.springframework.util.Assert; -import java.util.List; - import static java.lang.String.format; /** @@ -68,22 +85,22 @@ public AzureOpenAiImageModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiImag } public AzureOpenAiImageOptions getDefaultOptions() { - return defaultOptions; + return this.defaultOptions; } @Override public ImageResponse call(ImagePrompt imagePrompt) { ImageGenerationOptions imageGenerationOptions = toOpenAiImageOptions(imagePrompt); String deploymentOrModelName = getDeploymentName(imagePrompt); - if (logger.isTraceEnabled()) { - logger.trace("Azure ImageGenerationOptions call {} with the following options : {} ", deploymentOrModelName, - toPrettyJson(imageGenerationOptions)); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Azure ImageGenerationOptions call {} with the following options : {} ", + deploymentOrModelName, toPrettyJson(imageGenerationOptions)); } - var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); + var images = this.openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); - if (logger.isTraceEnabled()) { - logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); } List imageGenerations = images.getData().stream().map(entry -> { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java index be15fbfd10c..2e6d13c572f 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java @@ -1,12 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai; import java.util.Objects; import com.fasterxml.jackson.annotation.JsonInclude; -import org.springframework.ai.image.ImageOptions; - import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.image.ImageOptions; + /** * The configuration information for a image generation request. * @@ -89,9 +105,13 @@ public class AzureOpenAiImageOptions implements ImageOptions { @JsonProperty("user") private String user; + public static Builder builder() { + return new Builder(); + } + @Override public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -100,7 +120,7 @@ public void setN(Integer n) { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -109,7 +129,7 @@ public void setModel(String model) { @Override public Integer getWidth() { - return width; + return this.width; } public void setWidth(Integer width) { @@ -119,7 +139,7 @@ public void setWidth(Integer width) { @Override public Integer getHeight() { - return height; + return this.height; } public void setHeight(Integer height) { @@ -129,7 +149,7 @@ public void setHeight(Integer height) { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -148,7 +168,7 @@ public void setSize(String size) { } public String getUser() { - return user; + return this.user; } public void setUser(String user) { @@ -156,7 +176,7 @@ public void setUser(String user) { } public String getQuality() { - return quality; + return this.quality; } public void setQuality(String quality) { @@ -165,7 +185,7 @@ public void setQuality(String quality) { @Override public String getStyle() { - return style; + return this.style; } public void setStyle(String style) { @@ -173,41 +193,66 @@ public void setStyle(String style) { } public String getDeploymentName() { - return deploymentName; + return this.deploymentName; } public void setDeploymentName(String deploymentName) { this.deploymentName = deploymentName; } - public static Builder builder() { - return new Builder(); - } - @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AzureOpenAiImageOptions that)) + } + if (!(o instanceof AzureOpenAiImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) - && Objects.equals(deploymentName, that.deploymentName) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(quality, that.quality) - && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(size, that.size) - && Objects.equals(style, that.style) && Objects.equals(user, that.user); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.deploymentName, that.deploymentName) && Objects.equals(this.width, that.width) + && Objects.equals(this.height, that.height) && Objects.equals(this.quality, that.quality) + && Objects.equals(this.responseFormat, that.responseFormat) && Objects.equals(this.size, that.size) + && Objects.equals(this.style, that.style) && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(n, model, deploymentName, width, height, quality, responseFormat, size, style, user); + return Objects.hash(this.n, this.model, this.deploymentName, this.width, this.height, this.quality, + this.responseFormat, this.size, this.style, this.user); } @Override public String toString() { - return "AzureOpenAiImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", deploymentName='" - + deploymentName + '\'' + ", width=" + width + ", height=" + height + ", quality='" + quality + '\'' - + ", responseFormat='" + responseFormat + '\'' + ", size='" + size + '\'' + ", style='" + style + '\'' - + ", user='" + user + '\'' + '}'; + return "AzureOpenAiImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", deploymentName='" + + this.deploymentName + '\'' + ", width=" + this.width + ", height=" + this.height + ", quality='" + + this.quality + '\'' + ", responseFormat='" + this.responseFormat + '\'' + ", size='" + this.size + + '\'' + ", style='" + this.style + '\'' + ", user='" + this.user + '\'' + '}'; + } + + public enum ImageModel { + + /** + * The latest DALL·E model released in Nov 2023. + */ + DALL_E_3("dall-e-3"), + + /** + * The previous DALL·E model released in Nov 2022. The 2nd iteration of DALL·E + * with more realistic, accurate, and 4x greater resolution images than the + * original model. + */ + DALL_E_2("dall-e-2"); + + private final String value; + + ImageModel(String model) { + this.value = model; + } + + public String getValue() { + return this.value; + } + } public static class Builder { @@ -219,75 +264,49 @@ private Builder() { } public Builder withN(Integer n) { - options.setN(n); + this.options.setN(n); return this; } public Builder withModel(String model) { - options.setModel(model); + this.options.setModel(model); return this; } public Builder withDeploymentName(String deploymentName) { - options.setDeploymentName(deploymentName); + this.options.setDeploymentName(deploymentName); return this; } public Builder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); + this.options.setResponseFormat(responseFormat); return this; } public Builder withWidth(Integer width) { - options.setWidth(width); + this.options.setWidth(width); return this; } public Builder withHeight(Integer height) { - options.setHeight(height); + this.options.setHeight(height); return this; } public Builder withUser(String user) { - options.setUser(user); + this.options.setUser(user); return this; } public AzureOpenAiImageOptions build() { - return options; + return this.options; } public Builder withStyle(String style) { - options.setStyle(style); + this.options.setStyle(style); return this; } } - public enum ImageModel { - - /** - * The latest DALL·E model released in Nov 2023. - */ - DALL_E_3("dall-e-3"), - - /** - * The previous DALL·E model released in Nov 2022. The 2nd iteration of DALL·E - * with more realistic, accurate, and 4x greater resolution images than the - * original model. - */ - DALL_E_2("dall-e-2"); - - private final String value; - - ImageModel(String model) { - this.value = model; - } - - public String getValue() { - return this.value; - } - - } - } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java index 31bcb745852..fd83532ec77 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; /** diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java index 1411817682e..82c1f57b5d1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.lang.reflect.Constructor; @@ -49,6 +50,15 @@ */ public class MergeUtils { + private static final Class[] CHAT_COMPLETIONS_CONSTRUCTOR_ARG_TYPES = new Class[] { String.class, + OffsetDateTime.class, List.class, CompletionsUsage.class }; + + private static final Class[] chatChoiceConstructorArgumentTypes = new Class[] { + ChatChoiceLogProbabilityInfo.class, int.class, CompletionsFinishReason.class }; + + private static final Class[] chatResponseMessageConstructorArgumentTypes = new Class[] { ChatRole.class, + String.class }; + /** * Create a new instance of the given class using the constructor at the given index. * Can be used to create instances with private constructors. @@ -106,9 +116,6 @@ public static ChatCompletions emptyChatCompletions() { return chatCompletionsInstance; } - private static final Class[] CHAT_COMPLETIONS_CONSTRUCTOR_ARG_TYPES = new Class[] { String.class, - OffsetDateTime.class, List.class, CompletionsUsage.class }; - /** * Merge two ChatCompletions instances into a single ChatCompletions instance. * @param left the left ChatCompletions instance. @@ -158,9 +165,6 @@ public static ChatCompletions mergeChatCompletions(ChatCompletions left, ChatCom return instance; } - private static final Class[] chatChoiceConstructorArgumentTypes = new Class[] { - ChatChoiceLogProbabilityInfo.class, int.class, CompletionsFinishReason.class }; - /** * Merge two ChatChoice instances into a single ChatChoice instance. * @param left the left ChatChoice instance to merge. @@ -211,9 +215,6 @@ private static ChatChoice mergeChatChoice(ChatChoice left, ChatChoice right) { return instance; } - private static final Class[] chatResponseMessageConstructorArgumentTypes = new Class[] { ChatRole.class, - String.class }; - /** * Merge two ChatResponseMessage instances into a single ChatResponseMessage instance. * @param left the left ChatResponseMessage instance to merge. diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java index 488870bcbc1..75ba720b02c 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.aot; import com.azure.ai.openai.OpenAIAsyncClient; diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java index f64a805a146..a55ecd604ed 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import org.springframework.ai.audio.transcription.AudioTranscriptionResponseMetadata; @@ -26,11 +27,15 @@ */ public class AzureOpenAiAudioTranscriptionResponseMetadata extends AudioTranscriptionResponseMetadata { - protected static final String AI_METADATA_STRING = "{ @type: %1$s }"; - public static final AzureOpenAiAudioTranscriptionResponseMetadata NULL = new AzureOpenAiAudioTranscriptionResponseMetadata() { + }; + protected static final String AI_METADATA_STRING = "{ @type: %1$s }"; + + protected AzureOpenAiAudioTranscriptionResponseMetadata() { + } + public static AzureOpenAiAudioTranscriptionResponseMetadata from( AzureOpenAiAudioTranscriptionOptions.StructuredResponse result) { Assert.notNull(result, "AzureOpenAI Transcription must not be null"); @@ -42,9 +47,6 @@ public static AzureOpenAiAudioTranscriptionResponseMetadata from(String result) return new AzureOpenAiAudioTranscriptionResponseMetadata(); } - protected AzureOpenAiAudioTranscriptionResponseMetadata() { - } - @Override public String toString() { return AI_METADATA_STRING.formatted(getClass().getName()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java index 8ec132871bf..8fe0fa1e42b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import com.azure.ai.openai.models.EmbeddingsUsage; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; @@ -27,11 +29,6 @@ */ public class AzureOpenAiEmbeddingUsage implements Usage { - public static AzureOpenAiEmbeddingUsage from(EmbeddingsUsage usage) { - Assert.notNull(usage, "EmbeddingsUsage must not be null"); - return new AzureOpenAiEmbeddingUsage(usage); - } - private final EmbeddingsUsage usage; public AzureOpenAiEmbeddingUsage(EmbeddingsUsage usage) { @@ -39,6 +36,11 @@ public AzureOpenAiEmbeddingUsage(EmbeddingsUsage usage) { this.usage = usage; } + public static AzureOpenAiEmbeddingUsage from(EmbeddingsUsage usage) { + Assert.notNull(usage, "EmbeddingsUsage must not be null"); + return new AzureOpenAiEmbeddingUsage(usage); + } + protected EmbeddingsUsage getUsage() { return this.usage; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java index 44b429e9f72..eecc94ef78a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java @@ -1,9 +1,25 @@ -package org.springframework.ai.azure.openai.metadata; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.image.ImageGenerationMetadata; +package org.springframework.ai.azure.openai.metadata; import java.util.Objects; +import org.springframework.ai.image.ImageGenerationMetadata; + /** * Represents the metadata for image generation using Azure OpenAI. * @@ -19,25 +35,27 @@ public AzureOpenAiImageGenerationMetadata(String revisedPrompt) { } public String getRevisedPrompt() { - return revisedPrompt; + return this.revisedPrompt; } public String toString() { - return "AzureOpenAiImageGenerationMetadata{" + "revisedPrompt='" + revisedPrompt + '\'' + '}'; + return "AzureOpenAiImageGenerationMetadata{" + "revisedPrompt='" + this.revisedPrompt + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AzureOpenAiImageGenerationMetadata that)) + } + if (!(o instanceof AzureOpenAiImageGenerationMetadata that)) { return false; - return Objects.equals(revisedPrompt, that.revisedPrompt); + } + return Objects.equals(this.revisedPrompt, that.revisedPrompt); } @Override public int hashCode() { - return Objects.hash(revisedPrompt); + return Objects.hash(this.revisedPrompt); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java index 6d01d5cbb84..cdc24d0abdf 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java @@ -1,13 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai.metadata; +import java.util.Objects; + import com.azure.ai.openai.models.ImageGenerations; + import org.springframework.ai.image.ImageResponseMetadata; -import org.springframework.ai.model.MutableResponseMetadata; import org.springframework.util.Assert; -import java.util.HashMap; -import java.util.Objects; - /** * Represents metadata associated with an image response from the Azure OpenAI image * model. It provides additional information about the generative response from the Azure @@ -20,15 +35,15 @@ public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata { private final Long created; + protected AzureOpenAiImageResponseMetadata(Long created) { + this.created = created; + } + public static AzureOpenAiImageResponseMetadata from(ImageGenerations openAiImageResponse) { Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null"); return new AzureOpenAiImageResponseMetadata(openAiImageResponse.getCreatedAt().toEpochSecond()); } - protected AzureOpenAiImageResponseMetadata(Long created) { - this.created = created; - } - @Override public Long getCreated() { return this.created; @@ -36,21 +51,23 @@ public Long getCreated() { @Override public String toString() { - return "AzureOpenAiImageResponseMetadata{" + "created=" + created + '}'; + return "AzureOpenAiImageResponseMetadata{" + "created=" + this.created + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AzureOpenAiImageResponseMetadata that)) + } + if (!(o instanceof AzureOpenAiImageResponseMetadata that)) { return false; - return Objects.equals(created, that.created); + } + return Objects.equals(this.created, that.created); } @Override public int hashCode() { - return Objects.hash(created); + return Objects.hash(this.created); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java index 056d44eb03a..b0dd15d1367 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import com.azure.ai.openai.models.ChatCompletions; @@ -30,6 +31,13 @@ */ public class AzureOpenAiUsage implements Usage { + private final CompletionsUsage usage; + + public AzureOpenAiUsage(CompletionsUsage usage) { + Assert.notNull(usage, "CompletionsUsage must not be null"); + this.usage = usage; + } + public static AzureOpenAiUsage from(ChatCompletions chatCompletions) { Assert.notNull(chatCompletions, "ChatCompletions must not be null"); return from(chatCompletions.getUsage()); @@ -39,13 +47,6 @@ public static AzureOpenAiUsage from(CompletionsUsage usage) { return new AzureOpenAiUsage(usage); } - private final CompletionsUsage usage; - - public AzureOpenAiUsage(CompletionsUsage usage) { - Assert.notNull(usage, "CompletionsUsage must not be null"); - this.usage = usage; - } - protected CompletionsUsage getUsage() { return this.usage; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index dbc6fa46d83..e686d17fa1c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,10 +16,12 @@ package org.springframework.ai.azure.openai; -import com.azure.ai.openai.OpenAIClient; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; -import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; import org.junit.jupiter.api.Test; @@ -30,10 +32,6 @@ import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -42,6 +40,11 @@ */ public class AzureChatCompletionsOptionsTests { + private static Stream providePresencePenaltyAndFrequencyPenaltyTest() { + return Stream.of(Arguments.of(0.0, 0.0), Arguments.of(0.0, 1.0), Arguments.of(1.0, 0.0), Arguments.of(1.0, 1.0), + Arguments.of(1.0, null), Arguments.of(null, 1.0), Arguments.of(null, null)); + } + @Test public void createRequestWithChatOptions() { @@ -132,11 +135,6 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class); } - private static Stream providePresencePenaltyAndFrequencyPenaltyTest() { - return Stream.of(Arguments.of(0.0, 0.0), Arguments.of(0.0, 1.0), Arguments.of(1.0, 0.0), Arguments.of(1.0, 1.0), - Arguments.of(1.0, null), Arguments.of(null, 1.0), Arguments.of(null, null)); - } - @ParameterizedTest @MethodSource("providePresencePenaltyAndFrequencyPenaltyTest") public void createChatOptionsWithPresencePenaltyAndFrequencyPenalty(Double presencePenalty, diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java index 18fe0e56af1..62782438526 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.List; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java index a8a7d44ae15..e3fbcc92a2d 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; @@ -6,6 +22,7 @@ import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.beans.factory.annotation.Autowired; @@ -38,8 +55,9 @@ void transcriptionTest() { .withResponseFormat(AzureOpenAiAudioTranscriptionOptions.TranscriptResponseFormat.TEXT) .withTemperature(0f) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } @@ -54,8 +72,9 @@ void transcriptionTestWithOptions() { .withTemperature(0f) .withResponseFormat(responseFormat) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index 4b931b799d8..8babb2e0c8c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,17 @@ package org.springframework.ai.azure.openai; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; @@ -35,13 +37,10 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; - -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.OpenAIServiceVersion; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.core.http.policy.HttpLogOptions; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; + +import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Soby Chacko @@ -57,16 +56,13 @@ public class AzureOpenAiChatClientIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test void call() { // @formatter:off - ChatResponse response = chatClient.prompt() + ChatResponse response = this.chatClient.prompt() .advisors(new SimpleLoggerAdvisor()) - .system(s -> s.text(systemTextResource) + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -84,7 +80,7 @@ void beanStreamOutputConverterRecords() { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = chatClient + Flux chatResponse = this.chatClient .prompt() .advisors(new SimpleLoggerAdvisor()) .user(u -> u @@ -117,12 +113,12 @@ void streamingAndImperativeResponsesContainIdenticalRelevantResults() { + "List them with a numerical index. Do not use any abbreviations in state or capitals."; // Imperative call - String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content(); + String rawDataFromImperativeCall = this.chatClient.prompt(prompt).call().content(); String imperativeStatesData = extractStatesData(rawDataFromImperativeCall); String formattedImperativeResponse = formatResponse(imperativeStatesData); // Streaming call - String stitchedResponseFromStream = chatClient.prompt(prompt) + String stitchedResponseFromStream = this.chatClient.prompt(prompt) .stream() .content() .collectList() @@ -150,6 +146,10 @@ private String formatResponse(String response) { return String.join("\n", Arrays.stream(response.split("\n")).map(String::strip).toArray(String[]::new)); } + record ActorsFilms(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index aaad145b5d6..ffa657aa9f0 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.OpenAIServiceVersion; import com.azure.core.credential.AzureKeyCredential; @@ -23,6 +32,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -44,14 +54,6 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import java.io.IOException; -import java.net.URL; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; import static org.assertj.core.api.Assertions.assertThat; @@ -77,7 +79,7 @@ void roleTest() { UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates."); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -96,12 +98,12 @@ void testMessageHistory() { Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); System.out.println(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); @@ -120,7 +122,7 @@ void listOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); List list = outputConverter.convert(generation.getOutput().getContent()); assertThat(list).hasSize(5); @@ -139,7 +141,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -158,7 +160,7 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isNotNull(); @@ -176,7 +178,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -197,7 +199,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -221,7 +223,7 @@ void multiModalityImageUrl() throws IOException { URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() @@ -239,7 +241,7 @@ void multiModalityImageResource() { Resource resource = new ClassPathResource("multimodality/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, resource)) .call() @@ -252,9 +254,11 @@ void multiModalityImageResource() { } record ActorsFilms(String actor, List movies) { + } record ActorsFilmsRecord(String actor, List movies) { + } @SpringBootConfiguration diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java index 52a184eb6a2..2e194ea2d10 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.azure.openai; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.azure.openai; import java.util.List; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -37,13 +42,8 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.OpenAIServiceVersion; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.core.http.policy.HttpLogOptions; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; +import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Soby Chacko @@ -54,14 +54,14 @@ class AzureOpenAiChatModelObservationIT { @Autowired - private AzureOpenAiChatModel chatModel; + TestObservationRegistry observationRegistry; @Autowired - TestObservationRegistry observationRegistry; + private AzureOpenAiChatModel chatModel; @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -78,7 +78,7 @@ void observationForImperativeChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -102,7 +102,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); assertThat(responses).hasSizeGreaterThan(10); @@ -123,7 +123,7 @@ void observationForStreamingChatOperation() { private void validate(ChatResponseMetadata responseMetadata, boolean checkModel) { - TestObservationRegistryAssert.That that = TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.That that = TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java index 0ee62a147e5..c18114c67dc 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.List; @@ -22,6 +23,7 @@ import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.beans.factory.annotation.Autowired; @@ -41,18 +43,18 @@ class AzureOpenAiEmbeddingModelIT { @Test void singleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - System.out.println(embeddingModel.dimensions()); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + System.out.println(this.embeddingModel.dimensions()); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -60,7 +62,7 @@ void batchEmbedding() { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @SpringBootConfiguration diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java index db7b05dfce1..dc94e4a94e3 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.azure.openai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.azure.openai; import java.util.List; +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,12 +40,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.core.credential.AzureKeyCredential; - -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link AzureOpenAiEmbeddingModel}. @@ -69,13 +69,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java index 124cd4485b5..48df6e123e8 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.azure.openai; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.ai.azure.openai; import java.io.IOException; import java.io.UnsupportedEncodingException; @@ -26,8 +25,14 @@ import java.util.Queue; import java.util.concurrent.ConcurrentLinkedDeque; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okio.Buffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.InitializingBean; @@ -43,11 +48,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import okhttp3.mockwebserver.Dispatcher; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; -import okio.Buffer; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** * Spring {@link Configuration} for AI integration testing using mock objects. @@ -205,22 +206,22 @@ private String getBody(MockHttpServletResponse response) { */ static class MockWebServerFactoryBean implements FactoryBean, InitializingBean, DisposableBean { - private Dispatcher dispatcher; - private final Logger logger = LoggerFactory.getLogger(getClass().getName()); - private MockWebServer mockWebServer; - private final Queue queuedResponses = new ConcurrentLinkedDeque<>(); - public void setDispatcher(@Nullable Dispatcher dispatcher) { - this.dispatcher = dispatcher; - } + private Dispatcher dispatcher; + + private MockWebServer mockWebServer; protected Optional getDispatcher() { return Optional.ofNullable(this.dispatcher); } + public void setDispatcher(@Nullable Dispatcher dispatcher) { + this.dispatcher = dispatcher; + } + protected Logger getLogger() { return this.logger; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java index e4a12a846c4..1c0a84cade3 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; +import okhttp3.HttpUrl; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockWebServer; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; @@ -24,10 +27,6 @@ import org.springframework.context.annotation.Profile; import org.springframework.test.web.servlet.MockMvc; -import okhttp3.HttpUrl; -import okhttp3.mockwebserver.Dispatcher; -import okhttp3.mockwebserver.MockWebServer; - /** * {@link SpringBootConfiguration} for testing {@literal Azure OpenAI's} API using mock * objects. diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java index eaec2cbdddd..8984fe5a3ee 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.aot; import java.util.Set; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 635407cd766..19bc2c7308f 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.function; import java.util.ArrayList; @@ -22,21 +23,21 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; @@ -44,7 +45,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; @@ -69,7 +69,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -77,7 +77,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -93,7 +93,7 @@ void functionCallSequentialTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -101,7 +101,7 @@ void functionCallSequentialTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -115,7 +115,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -123,7 +123,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) @@ -152,7 +152,7 @@ void functionCallSequentialAndStreamTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -160,7 +160,7 @@ void functionCallSequentialAndStreamTest() { .build())) .build(); - var response = chatModel.stream(new Prompt(messages, promptOptions)); + var response = this.chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) @@ -182,6 +182,16 @@ void functionCallSequentialAndStreamTest() { @SpringBootConfiguration public static class TestConfiguration { + public static String getDeploymentName() { + String deploymentName = System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"); + if (StringUtils.hasText(deploymentName)) { + return deploymentName; + } + else { + return "gpt-4o"; + } + } + @Bean public OpenAIClientBuilder openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) @@ -199,16 +209,6 @@ public String selectedModel() { return Optional.ofNullable(System.getenv("AZURE_OPENAI_MODEL")).orElse(getDeploymentName()); } - public static String getDeploymentName() { - String deploymentName = System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"); - if (StringUtils.hasText(deploymentName)) { - return deploymentName; - } - else { - return "gpt-4o"; - } - } - } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java index 92747ed3023..e122e5f690d 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,29 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.function; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Christian Tzolov */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -63,28 +71,23 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java index f57dfb6a706..a6efd3c224a 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai.image; import com.azure.ai.openai.OpenAIClient; @@ -6,6 +22,7 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.azure.openai.AzureOpenAiImageModel; import org.springframework.ai.azure.openai.AzureOpenAiImageOptions; import org.springframework.ai.azure.openai.metadata.AzureOpenAiImageGenerationMetadata; @@ -39,7 +56,7 @@ void imageAsUrlTest() { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java index a8745010331..d962252ac71 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import java.nio.charset.StandardCharsets; @@ -25,14 +26,14 @@ import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.MockAzureOpenAiTestConfiguration; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-bedrock/pom.xml b/models/spring-ai-bedrock/pom.xml index 51c458c6514..29f48b7db01 100644 --- a/models/spring-ai-bedrock/pom.xml +++ b/models/spring-ai-bedrock/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java index 61389fffaac..6394090443a 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock; import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics; @@ -27,10 +28,6 @@ */ public class BedrockUsage implements Usage { - public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) { - return new BedrockUsage(usage); - } - private final AmazonBedrockInvocationMetrics usage; protected BedrockUsage(AmazonBedrockInvocationMetrics usage) { @@ -38,6 +35,10 @@ protected BedrockUsage(AmazonBedrockInvocationMetrics usage) { this.usage = usage; } + public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) { + return new BedrockUsage(usage); + } + protected AmazonBedrockInvocationMetrics getUsage() { return this.usage; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java index 95abde87268..001b2fd9896 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock; import java.util.List; @@ -33,12 +34,12 @@ public class MessageToPromptConverter { private static final String ASSISTANT_PROMPT = "Assistant:"; + private final String lineSeparator; + private String humanPrompt = HUMAN_PROMPT; private String assistantPrompt = ASSISTANT_PROMPT; - private final String lineSeparator; - private MessageToPromptConverter(String lineSeparator) { this.lineSeparator = lineSeparator; } @@ -84,9 +85,9 @@ protected String messageToString(Message message) { case SYSTEM: return message.getContent(); case USER: - return humanPrompt + " " + message.getContent(); + return this.humanPrompt + " " + message.getContent(); case ASSISTANT: - return assistantPrompt + " " + message.getContent(); + return this.assistantPrompt + " " + message.getContent(); case TOOL: throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java index 7bdc15e2d26..5625a55a11d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.util.List; @@ -20,11 +21,10 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -75,44 +75,14 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final AnthropicChatOptions options = new AnthropicChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withMaxTokensToSample(Integer maxTokensToSample) { - this.options.setMaxTokensToSample(maxTokensToSample); - return this; - } - - public Builder withTopK(Integer topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withAnthropicVersion(String anthropicVersion) { - this.options.setAnthropicVersion(anthropicVersion); - return this; - } - - public AnthropicChatOptions build() { - return this.options; - } - + public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withMaxTokensToSample(fromOptions.getMaxTokensToSample()) + .withTopK(fromOptions.getTopK()) + .withTopP(fromOptions.getTopP()) + .withStopSequences(fromOptions.getStopSequences()) + .withAnthropicVersion(fromOptions.getAnthropicVersion()) + .build(); } @Override @@ -201,14 +171,44 @@ public AnthropicChatOptions copy() { return fromOptions(this); } - public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withMaxTokensToSample(fromOptions.getMaxTokensToSample()) - .withTopK(fromOptions.getTopK()) - .withTopP(fromOptions.getTopP()) - .withStopSequences(fromOptions.getStopSequences()) - .withAnthropicVersion(fromOptions.getAnthropicVersion()) - .build(); + public static class Builder { + + private final AnthropicChatOptions options = new AnthropicChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withMaxTokensToSample(Integer maxTokensToSample) { + this.options.setMaxTokensToSample(maxTokensToSample); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withAnthropicVersion(String anthropicVersion) { + this.options.setAnthropicVersion(anthropicVersion); + return this; + } + + public AnthropicChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java index c4321c37459..f5f1f91be0c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.util.List; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java index c1235456b62..074c036352b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic.api; import java.time.Duration; @@ -118,6 +119,54 @@ public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credential // Anthropic Claude models: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html + @Override + public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { + Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); + return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); + } + + @Override + public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { + Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); + return this.internalInvocationStream(anthropicRequest, AnthropicChatResponse.class); + } + + /** + * Anthropic models version. + */ + public enum AnthropicChatModel implements ChatModelDescription { + /** + * anthropic.claude-instant-v1 + */ + CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), + /** + * anthropic.claude-v2 + */ + CLAUDE_V2("anthropic.claude-v2"), + /** + * anthropic.claude-v2:1 + */ + CLAUDE_V21("anthropic.claude-v2:1"); + + private final String id; + + AnthropicChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * AnthropicChatRequest encapsulates the request parameters for the Anthropic chat model. * https://docs.anthropic.com/claude/reference/complete_post @@ -196,13 +245,13 @@ public Builder withAnthropicVersion(String anthropicVersion) { public AnthropicChatRequest build() { return new AnthropicChatRequest( - prompt, - temperature, - maxTokensToSample, - topK, - topP, - stopSequences, - anthropicVersion + this.prompt, + this.temperature, + this.maxTokensToSample, + this.topK, + this.topP, + this.stopSequences, + this.anthropicVersion ); } } @@ -225,53 +274,5 @@ public record AnthropicChatResponse( @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { } - /** - * Anthropic models version. - */ - public enum AnthropicChatModel implements ChatModelDescription { - /** - * anthropic.claude-instant-v1 - */ - CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), - /** - * anthropic.claude-v2 - */ - CLAUDE_V2("anthropic.claude-v2"), - /** - * anthropic.claude-v2:1 - */ - CLAUDE_V21("anthropic.claude-v2:1"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - AnthropicChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); - } - - @Override - public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocationStream(anthropicRequest, AnthropicChatResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java index 45927c91139..86d8ad67b77 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; +import org.springframework.ai.chat.prompt.ChatOptions; /** * @author Ben Middleton @@ -74,44 +76,14 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final Anthropic3ChatOptions options = new Anthropic3ChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.setMaxTokens(maxTokens); - return this; - } - - public Builder withTopK(Integer topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withAnthropicVersion(String anthropicVersion) { - this.options.setAnthropicVersion(anthropicVersion); - return this; - } - - public Anthropic3ChatOptions build() { - return this.options; - } - + public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withTopK(fromOptions.getTopK()) + .withTopP(fromOptions.getTopP()) + .withStopSequences(fromOptions.getStopSequences()) + .withAnthropicVersion(fromOptions.getAnthropicVersion()) + .build(); } @Override @@ -190,14 +162,44 @@ public Anthropic3ChatOptions copy() { return fromOptions(this); } - public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withTopK(fromOptions.getTopK()) - .withTopP(fromOptions.getTopP()) - .withStopSequences(fromOptions.getStopSequences()) - .withAnthropicVersion(fromOptions.getAnthropicVersion()) - .build(); + public static class Builder { + + private final Anthropic3ChatOptions options = new Anthropic3ChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withAnthropicVersion(String anthropicVersion) { + this.options.setAnthropicVersion(anthropicVersion); + return this; + } + + public Anthropic3ChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index 4bef201bba3..deaa01f1304 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; import java.util.ArrayList; @@ -21,11 +22,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.DefaultUsage; -import org.springframework.ai.chat.metadata.Usage; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; @@ -35,13 +31,18 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index d9f9cf672eb..e6e8b96113f 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3.api; +import java.time.Duration; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.model.ChatModelDescription; import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; /** * Based on Bedrock's chatCompletionStream(AnthropicChatRequest anthropicRequest) { + Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); + return this.internalInvocationStream(anthropicRequest, AnthropicChatStreamingResponse.class); + } + + /** + * Anthropic models version. + */ + public enum AnthropicChatModel implements ChatModelDescription { + + /** + * anthropic.claude-instant-v1 + */ + CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), + /** + * anthropic.claude-v2 + */ + CLAUDE_V2("anthropic.claude-v2"), + /** + * anthropic.claude-v2:1 + */ + CLAUDE_V21("anthropic.claude-v2:1"), + /** + * anthropic.claude-3-sonnet-20240229-v1:0 + */ + CLAUDE_V3_SONNET("anthropic.claude-3-sonnet-20240229-v1:0"), + /** + * anthropic.claude-3-haiku-20240307-v1:0 + */ + CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), + /** + * anthropic.claude-3-opus-20240229-v1:0 + */ + CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"), + /** + * anthropic.claude-3-5-sonnet-20240620-v1:0 + */ + CLAUDE_V3_5_SONNET("anthropic.claude-3-5-sonnet-20240620-v1:0"), + /** + * anthropic.claude-3-5-sonnet-20241022-v2:0 + */ + CLAUDE_V3_5_SONNET_V2("anthropic.claude-3-5-sonnet-20241022-v2:0"); + + private final String id; + + AnthropicChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + + } + /** * AnthropicChatRequest encapsulates the request parameters for the Anthropic messages model. * https://docs.anthropic.com/claude/reference/messages_post @@ -208,14 +280,14 @@ public Builder withAnthropicVersion(String anthropicVersion) { public AnthropicChatRequest build() { return new AnthropicChatRequest( - messages, - system, - temperature, - maxTokens, - topK, - topP, - stopSequences, - anthropicVersion + this.messages, + this.system, + this.temperature, + this.maxTokens, + this.topK, + this.topP, + this.stopSequences, + this.anthropicVersion ); } } @@ -286,7 +358,9 @@ public record Source( // @formatter:off public Source(String mediaType, String data) { this("base64", mediaType, data); } + } + } /** @@ -317,6 +391,7 @@ public enum Role { ASSISTANT } + } /** @@ -329,6 +404,7 @@ public enum Role { @JsonInclude(Include.NON_NULL) public record AnthropicUsage(@JsonProperty("input_tokens") Integer inputTokens, @JsonProperty("output_tokens") Integer outputTokens) { + } /** @@ -356,6 +432,7 @@ public record AnthropicChatResponse(// formatter:off @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence, @JsonProperty("usage") AnthropicUsage usage, @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { // formatter:on + } /** @@ -432,77 +509,9 @@ public enum StreamingType { @JsonInclude(Include.NON_NULL) public record Delta(@JsonProperty("type") String type, @JsonProperty("text") String text, @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence) { - } - } - - /** - * Anthropic models version. - */ - public enum AnthropicChatModel implements ChatModelDescription { - - /** - * anthropic.claude-instant-v1 - */ - CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), - /** - * anthropic.claude-v2 - */ - CLAUDE_V2("anthropic.claude-v2"), - /** - * anthropic.claude-v2:1 - */ - CLAUDE_V21("anthropic.claude-v2:1"), - /** - * anthropic.claude-3-sonnet-20240229-v1:0 - */ - CLAUDE_V3_SONNET("anthropic.claude-3-sonnet-20240229-v1:0"), - /** - * anthropic.claude-3-haiku-20240307-v1:0 - */ - CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), - /** - * anthropic.claude-3-opus-20240229-v1:0 - */ - CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"), - /** - * anthropic.claude-3-5-sonnet-20240620-v1:0 - */ - CLAUDE_V3_5_SONNET("anthropic.claude-3-5-sonnet-20240620-v1:0"), - /** - * anthropic.claude-3-5-sonnet-20241022-v2:0 - */ - CLAUDE_V3_5_SONNET_V2("anthropic.claude-3-5-sonnet-20241022-v2:0"); - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; } - AnthropicChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - - } - - @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); - } - - @Override - public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocationStream(anthropicRequest, AnthropicChatStreamingResponse.class); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java index 7db24b3b8c6..b6f93d3b819 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.aot; import org.springframework.ai.bedrock.anthropic.AnthropicChatOptions; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 24a383adac0..11897200c28 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -175,24 +175,6 @@ public Region getRegion() { return this.region; } - /** - * Encapsulates the metrics about the model invocation. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html - * - * @param inputTokenCount The number of tokens in the input prompt. - * @param firstByteLatency The time in milliseconds between the request being sent and the first byte of the - * response being received. - * @param outputTokenCount The number of tokens in the generated text. - * @param invocationLatency The time in milliseconds between the request being sent and the response being received. - */ - @JsonInclude(Include.NON_NULL) - public record AmazonBedrockInvocationMetrics( - @JsonProperty("inputTokenCount") Long inputTokenCount, - @JsonProperty("firstByteLatency") Long firstByteLatency, - @JsonProperty("outputTokenCount") Long outputTokenCount, - @JsonProperty("invocationLatency") Long invocationLatency) { - } - /** * Compute the embedding for the given text. * @@ -337,5 +319,23 @@ protected Flux internalInvocationStream(I request, Class clazz) { return eventSink.asFlux(); } + + /** + * Encapsulates the metrics about the model invocation. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html + * + * @param inputTokenCount The number of tokens in the input prompt. + * @param firstByteLatency The time in milliseconds between the request being sent and the first byte of the + * response being received. + * @param outputTokenCount The number of tokens in the generated text. + * @param invocationLatency The time in milliseconds between the request being sent and the response being received. + */ + @JsonInclude(Include.NON_NULL) + public record AmazonBedrockInvocationMetrics( + @JsonProperty("inputTokenCount") Long inputTokenCount, + @JsonProperty("firstByteLatency") Long firstByteLatency, + @JsonProperty("outputTokenCount") Long outputTokenCount, + @JsonProperty("invocationLatency") Long invocationLatency) { + } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java index e9895fc1db1..4d235a2b264 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.util.List; @@ -24,13 +25,13 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java index 04f67f282ee..8e5e0a6898e 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.util.List; @@ -85,59 +86,17 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final BedrockCohereChatOptions options = new BedrockCohereChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withTopK(Integer topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.setMaxTokens(maxTokens); - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withReturnLikelihoods(ReturnLikelihoods returnLikelihoods) { - this.options.setReturnLikelihoods(returnLikelihoods); - return this; - } - - public Builder withNumGenerations(Integer numGenerations) { - this.options.setNumGenerations(numGenerations); - return this; - } - - public Builder withLogitBias(LogitBias logitBias) { - this.options.setLogitBias(logitBias); - return this; - } - - public Builder withTruncate(Truncate truncate) { - this.options.setTruncate(truncate); - return this; - } - - public BedrockCohereChatOptions build() { - return this.options; - } - + public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .withReturnLikelihoods(fromOptions.getReturnLikelihoods()) + .withNumGenerations(fromOptions.getNumGenerations()) + .withLogitBias(fromOptions.getLogitBias()) + .withTruncate(fromOptions.getTruncate()) + .build(); } @Override @@ -240,17 +199,59 @@ public BedrockCohereChatOptions copy() { return fromOptions(this); } - public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withStopSequences(fromOptions.getStopSequences()) - .withReturnLikelihoods(fromOptions.getReturnLikelihoods()) - .withNumGenerations(fromOptions.getNumGenerations()) - .withLogitBias(fromOptions.getLogitBias()) - .withTruncate(fromOptions.getTruncate()) - .build(); + public static class Builder { + + private final BedrockCohereChatOptions options = new BedrockCohereChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withReturnLikelihoods(ReturnLikelihoods returnLikelihoods) { + this.options.setReturnLikelihoods(returnLikelihoods); + return this; + } + + public Builder withNumGenerations(Integer numGenerations) { + this.options.setNumGenerations(numGenerations); + return this; + } + + public Builder withLogitBias(LogitBias logitBias) { + this.options.setLogitBias(logitBias); + return this; + } + + public Builder withTruncate(Truncate truncate) { + this.options.setTruncate(truncate); + return this; + } + + public BedrockCohereChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java index c070f250656..c34335d8b44 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.util.List; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java index 068d704545c..57e0de302bd 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -52,26 +53,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private BedrockCohereEmbeddingOptions options = new BedrockCohereEmbeddingOptions(); - - public Builder withInputType(InputType inputType) { - this.options.setInputType(inputType); - return this; - } - - public Builder withTruncate(Truncate truncate) { - this.options.setTruncate(truncate); - return this; - } - - public BedrockCohereEmbeddingOptions build() { - return this.options; - } - - } - public InputType getInputType() { return this.inputType; } @@ -100,4 +81,24 @@ public Integer getDimensions() { return null; } + public static class Builder { + + private BedrockCohereEmbeddingOptions options = new BedrockCohereEmbeddingOptions(); + + public Builder withInputType(InputType inputType) { + this.options.setInputType(inputType); + return this; + } + + public Builder withTruncate(Truncate truncate) { + this.options.setTruncate(truncate); + return this; + } + + public BedrockCohereEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java index 9feb62eeb46..454bd8aed4f 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -109,6 +109,52 @@ public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPr super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public CohereChatResponse chatCompletion(CohereChatRequest request) { + Assert.isTrue(!request.stream(), "The request must be configured to return the complete response!"); + return this.internalInvocation(request, CohereChatResponse.class); + } + + @Override + public Flux chatCompletionStream(CohereChatRequest request) { + Assert.isTrue(request.stream(), "The request must be configured to stream the response!"); + return this.internalInvocationStream(request, CohereChatResponse.Generation.class); + } + + /** + * Cohere models version. + */ + public enum CohereChatModel implements ChatModelDescription { + + /** + * cohere.command-light-text-v14 + */ + COHERE_COMMAND_LIGHT_V14("cohere.command-light-text-v14"), + + /** + * cohere.command-text-v14 + */ + COHERE_COMMAND_V14("cohere.command-text-v14"); + + private final String id; + + CohereChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * CohereChatRequest encapsulates the request parameters for the Cohere command model. * @@ -143,15 +189,12 @@ public record CohereChatRequest( @JsonProperty("truncate") Truncate truncate) { /** - * Prevents the model from generating unwanted tokens or incentivize the model to include desired tokens. - * - * @param token The token likelihoods. - * @param bias A float between -10 and 10. + * Get CohereChatRequest builder. + * @param prompt compulsory request prompt parameter. + * @return CohereChatRequest builder. */ - @JsonInclude(Include.NON_NULL) - public record LogitBias( - @JsonProperty("token") String token, - @JsonProperty("bias") Float bias) { + public static Builder builder(String prompt) { + return new Builder(prompt); } /** @@ -192,12 +235,15 @@ public enum Truncate { } /** - * Get CohereChatRequest builder. - * @param prompt compulsory request prompt parameter. - * @return CohereChatRequest builder. + * Prevents the model from generating unwanted tokens or incentivize the model to include desired tokens. + * + * @param token The token likelihoods. + * @param bias A float between -10 and 10. */ - public static Builder builder(String prompt) { - return new Builder(prompt); + @JsonInclude(Include.NON_NULL) + public record LogitBias( + @JsonProperty("token") String token, + @JsonProperty("bias") Float bias) { } /** @@ -272,17 +318,17 @@ public Builder withTruncate(Truncate truncate) { public CohereChatRequest build() { return new CohereChatRequest( - prompt, - temperature, - topP, - topK, - maxTokens, - stopSequences, - returnLikelihoods, - stream, - numGenerations, - logitBias, - truncate + this.prompt, + this.temperature, + this.topP, + this.topK, + this.maxTokens, + this.stopSequences, + this.returnLikelihoods, + this.stream, + this.numGenerations, + this.logitBias, + this.truncate ); } } @@ -331,16 +377,6 @@ public record Generation( @JsonProperty("index") Integer index, @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - /** - * @param token The token. - * @param likelihood The likelihood of the token. - */ - @JsonInclude(Include.NON_NULL) - public record TokenLikelihood( - @JsonProperty("token") String token, - @JsonProperty("likelihood") Float likelihood) { - } - /** * The reason the response finished being generated. */ @@ -363,53 +399,17 @@ public enum FinishReason { */ ERROR_TOXIC } - } - } - - /** - * Cohere models version. - */ - public enum CohereChatModel implements ChatModelDescription { - - /** - * cohere.command-light-text-v14 - */ - COHERE_COMMAND_LIGHT_V14("cohere.command-light-text-v14"), - /** - * cohere.command-text-v14 - */ - COHERE_COMMAND_V14("cohere.command-text-v14"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - CohereChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; + /** + * @param token The token. + * @param likelihood The likelihood of the token. + */ + @JsonInclude(Include.NON_NULL) + public record TokenLikelihood( + @JsonProperty("token") String token, + @JsonProperty("likelihood") Float likelihood) { + } } } - - @Override - public CohereChatResponse chatCompletion(CohereChatRequest request) { - Assert.isTrue(!request.stream(), "The request must be configured to return the complete response!"); - return this.internalInvocation(request, CohereChatResponse.class); - } - - @Override - public Flux chatCompletionStream(CohereChatRequest request) { - Assert.isTrue(request.stream(), "The request must be configured to stream the response!"); - return this.internalInvocationStream(request, CohereChatResponse.Generation.class); - } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java index 30938fe9a3a..e69f229d6d8 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -109,6 +109,39 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public CohereEmbeddingResponse embedding(CohereEmbeddingRequest request) { + return this.internalInvocation(request, CohereEmbeddingResponse.class); + } + + /** + * Cohere Embedding model ids. https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html + */ + public enum CohereEmbeddingModel { + /** + * cohere.embed-multilingual-v3 + */ + COHERE_EMBED_MULTILINGUAL_V1("cohere.embed-multilingual-v3"), + /** + * cohere.embed-english-v3 + */ + COHERE_EMBED_ENGLISH_V3("cohere.embed-english-v3"); + + private final String id; + + CohereEmbeddingModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + } + /** * The Cohere Embed model request. * @@ -190,38 +223,5 @@ public record CohereEmbeddingResponse( @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { } - /** - * Cohere Embedding model ids. https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html - */ - public enum CohereEmbeddingModel { - /** - * cohere.embed-multilingual-v3 - */ - COHERE_EMBED_MULTILINGUAL_V1("cohere.embed-multilingual-v3"), - /** - * cohere.embed-english-v3 - */ - COHERE_EMBED_ENGLISH_V3("cohere.embed-english-v3"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return this.id; - } - - CohereEmbeddingModel(String value) { - this.id = value; - } - - } - - @Override - public CohereEmbeddingResponse embedding(CohereEmbeddingRequest request) { - return this.internalInvocation(request, CohereEmbeddingResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java index 7cb9bf0ac36..ab463b9116f 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,10 +19,10 @@ import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; @@ -57,6 +57,10 @@ public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) { .build()); } + public static Builder builder(Ai21Jurassic2ChatBedrockApi chatApi) { + return new Builder(chatApi); + } + @Override public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); @@ -88,8 +92,9 @@ private Ai21Jurassic2ChatRequest createRequest(Prompt prompt) { return request; } - public static Builder builder(Ai21Jurassic2ChatBedrockApi chatApi) { - return new Builder(chatApi); + @Override + public ChatOptions getDefaultOptions() { + return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions); } public static class Builder { @@ -108,15 +113,10 @@ public Builder withOptions(BedrockAi21Jurassic2ChatOptions options) { } public BedrockAi21Jurassic2ChatModel build() { - return new BedrockAi21Jurassic2ChatModel(chatApi, - options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build()); + return new BedrockAi21Jurassic2ChatModel(this.chatApi, + this.options != null ? this.options : BedrockAi21Jurassic2ChatOptions.builder().build()); } } - @Override - public ChatOptions getDefaultOptions() { - return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions); - } - } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java index eb8ce968aac..aa292edfa78 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,13 @@ package org.springframework.ai.bedrock.jurassic2; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; +import org.springframework.ai.chat.prompt.ChatOptions; /** * Request body for the /complete endpoint of the Jurassic-2 API. @@ -101,12 +102,31 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { // Getters and setters + public static Builder builder() { + return new Builder(); + } + + public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) { + return builder().withPrompt(fromOptions.getPrompt()) + .withNumResults(fromOptions.getNumResults()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withMinTokens(fromOptions.getMinTokens()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withStopSequences(fromOptions.getStopSequences()) + .withFrequencyPenaltyOptions(fromOptions.getFrequencyPenaltyOptions()) + .withPresencePenaltyOptions(fromOptions.getPresencePenaltyOptions()) + .withCountPenaltyOptions(fromOptions.getCountPenaltyOptions()) + .build(); + } + /** * Gets the prompt text for the model to continue. * @return The prompt text. */ public String getPrompt() { - return prompt; + return this.prompt; } /** @@ -122,7 +142,7 @@ public void setPrompt(String prompt) { * @return The number of results. */ public Integer getNumResults() { - return numResults; + return this.numResults; } /** @@ -139,7 +159,7 @@ public void setNumResults(Integer numResults) { */ @Override public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } /** @@ -155,7 +175,7 @@ public void setMaxTokens(Integer maxTokens) { * @return The minimum number of tokens. */ public Integer getMinTokens() { - return minTokens; + return this.minTokens; } /** @@ -172,7 +192,7 @@ public void setMinTokens(Integer minTokens) { */ @Override public Double getTemperature() { - return temperature; + return this.temperature; } /** @@ -190,7 +210,7 @@ public void setTemperature(Double temperature) { */ @Override public Double getTopP() { - return topP; + return this.topP; } /** @@ -208,7 +228,7 @@ public void setTopP(Double topP) { */ @Override public Integer getTopK() { - return topK; + return this.topK; } /** @@ -225,7 +245,7 @@ public void setTopK(Integer topK) { */ @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } /** @@ -254,7 +274,7 @@ public void setFrequencyPenalty(Double frequencyPenalty) { * @return The frequency penalty object. */ public Penalty getFrequencyPenaltyOptions() { - return frequencyPenaltyOptions; + return this.frequencyPenaltyOptions; } /** @@ -283,7 +303,7 @@ public void setPresencePenalty(Double presencePenalty) { * @return The presence penalty object. */ public Penalty getPresencePenaltyOptions() { - return presencePenaltyOptions; + return this.presencePenaltyOptions; } /** @@ -299,7 +319,7 @@ public void setPresencePenaltyOptions(Penalty presencePenaltyOptions) { * @return The count penalty object. */ public Penalty getCountPenaltyOptions() { - return countPenaltyOptions; + return this.countPenaltyOptions; } /** @@ -316,8 +336,9 @@ public String getModel() { return null; } - public static Builder builder() { - return new Builder(); + @Override + public BedrockAi21Jurassic2ChatOptions copy() { + return fromOptions(this); } public static class Builder { @@ -325,62 +346,62 @@ public static class Builder { private final BedrockAi21Jurassic2ChatOptions request = new BedrockAi21Jurassic2ChatOptions(); public Builder withPrompt(String prompt) { - request.setPrompt(prompt); + this.request.setPrompt(prompt); return this; } public Builder withNumResults(Integer numResults) { - request.setNumResults(numResults); + this.request.setNumResults(numResults); return this; } public Builder withMaxTokens(Integer maxTokens) { - request.setMaxTokens(maxTokens); + this.request.setMaxTokens(maxTokens); return this; } public Builder withMinTokens(Integer minTokens) { - request.setMinTokens(minTokens); + this.request.setMinTokens(minTokens); return this; } public Builder withTemperature(Double temperature) { - request.setTemperature(temperature); + this.request.setTemperature(temperature); return this; } public Builder withTopP(Double topP) { - request.setTopP(topP); + this.request.setTopP(topP); return this; } public Builder withStopSequences(List stopSequences) { - request.setStopSequences(stopSequences); + this.request.setStopSequences(stopSequences); return this; } public Builder withTopK(Integer topKReturn) { - request.setTopK(topKReturn); + this.request.setTopK(topKReturn); return this; } public Builder withFrequencyPenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty frequencyPenalty) { - request.setFrequencyPenaltyOptions(frequencyPenalty); + this.request.setFrequencyPenaltyOptions(frequencyPenalty); return this; } public Builder withPresencePenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty presencePenalty) { - request.setPresencePenaltyOptions(presencePenalty); + this.request.setPresencePenaltyOptions(presencePenalty); return this; } public Builder withCountPenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty countPenalty) { - request.setCountPenaltyOptions(countPenalty); + this.request.setCountPenaltyOptions(countPenalty); return this; } public BedrockAi21Jurassic2ChatOptions build() { - return request; + return this.request; } } @@ -446,31 +467,12 @@ public Builder applyToEmojis(Boolean applyToEmojis) { } public Penalty build() { - return new Penalty(scale, applyToNumbers, applyToPunctuations, applyToStopwords, applyToWhitespaces, - applyToEmojis); + return new Penalty(this.scale, this.applyToNumbers, this.applyToPunctuations, this.applyToStopwords, + this.applyToWhitespaces, this.applyToEmojis); } } - } - @Override - public BedrockAi21Jurassic2ChatOptions copy() { - return fromOptions(this); - } - - public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) { - return builder().withPrompt(fromOptions.getPrompt()) - .withNumResults(fromOptions.getNumResults()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withMinTokens(fromOptions.getMinTokens()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withStopSequences(fromOptions.getStopSequences()) - .withFrequencyPenaltyOptions(fromOptions.getFrequencyPenaltyOptions()) - .withPresencePenaltyOptions(fromOptions.getPresencePenaltyOptions()) - .withCountPenaltyOptions(fromOptions.getCountPenaltyOptions()) - .build(); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java index c3ab019d129..9fa7104cf37 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,16 +22,15 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; - import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; import org.springframework.ai.model.ChatModelDescription; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - /** * Java client for the Bedrock Jurassic2 chat model. * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html @@ -110,6 +109,45 @@ public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider creden super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public Ai21Jurassic2ChatResponse chatCompletion(Ai21Jurassic2ChatRequest request) { + return this.internalInvocation(request, Ai21Jurassic2ChatResponse.class); + } + + /** + * Ai21 Jurassic2 models version. + */ + public enum Ai21Jurassic2ChatModel implements ChatModelDescription { + + /** + * ai21.j2-mid-v1 + */ + AI21_J2_MID_V1("ai21.j2-mid-v1"), + + /** + * ai21.j2-ultra-v1 + */ + AI21_J2_ULTRA_V1("ai21.j2-ultra-v1"); + + private final String id; + + Ai21Jurassic2ChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * AI21 Jurassic2 chat request parameters. * @@ -141,6 +179,10 @@ public record Ai21Jurassic2ChatRequest( @JsonProperty("presencePenalty") FloatScalePenalty presencePenalty, @JsonProperty("frequencyPenalty") IntegerScalePenalty frequencyPenalty) { + public static Builder builder(String prompt) { + return new Builder(prompt); + } + /** * Penalty with integer scale value. * @@ -192,11 +234,6 @@ public record FloatScalePenalty(@JsonProperty("scale") Float scale, @JsonProperty("applyToEmojis") boolean applyToEmojis) { } - - - public static Builder builder(String prompt) { - return new Builder(prompt); - } public static class Builder { private String prompt; private Double temperature; @@ -248,14 +285,14 @@ public Builder withFrequencyPenalty(IntegerScalePenalty frequencyPenalty) { public Ai21Jurassic2ChatRequest build() { return new Ai21Jurassic2ChatRequest( - prompt, - temperature, - topP, - maxTokens, - stopSequences, - countPenalty, - presencePenalty, - frequencyPenalty + this.prompt, + this.temperature, + this.topP, + this.maxTokens, + this.stopSequences, + this.countPenalty, + this.presencePenalty, + this.frequencyPenalty ); } } @@ -370,45 +407,6 @@ public record FinishReason( } } - /** - * Ai21 Jurassic2 models version. - */ - public enum Ai21Jurassic2ChatModel implements ChatModelDescription { - - /** - * ai21.j2-mid-v1 - */ - AI21_J2_MID_V1("ai21.j2-mid-v1"), - - /** - * ai21.j2-ultra-v1 - */ - AI21_J2_ULTRA_V1("ai21.j2-ultra-v1"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - Ai21Jurassic2ChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public Ai21Jurassic2ChatResponse chatCompletion(Ai21Jurassic2ChatRequest request) { - return this.internalInvocation(request, Ai21Jurassic2ChatResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java index 51b83a7be07..1944b85ab25 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; import java.util.List; @@ -23,13 +24,13 @@ import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java index ed50bd3c5ea..bdeb7543a73 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -22,8 +25,6 @@ import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -52,29 +53,11 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withMaxGenLen(Integer maxGenLen) { - this.options.setMaxGenLen(maxGenLen); - return this; - } - - public BedrockLlamaChatOptions build() { - return this.options; - } - + public static BedrockLlamaChatOptions fromOptions(BedrockLlamaChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withMaxGenLen(fromOptions.getMaxGenLen()) + .build(); } @Override @@ -149,11 +132,29 @@ public BedrockLlamaChatOptions copy() { return fromOptions(this); } - public static BedrockLlamaChatOptions fromOptions(BedrockLlamaChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withMaxGenLen(fromOptions.getMaxGenLen()) - .build(); + public static class Builder { + + private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withMaxGenLen(Integer maxGenLen) { + this.options.setMaxGenLen(maxGenLen); + return this; + } + + public BedrockLlamaChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java index a476f70cf73..4a76ee485e3 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama.api; +import java.time.Duration; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -28,8 +31,6 @@ import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; import org.springframework.ai.model.ChatModelDescription; -import java.time.Duration; - // @formatter:off /** * Java client for the Bedrock Llama chat model. @@ -109,6 +110,95 @@ public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPro super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public LlamaChatResponse chatCompletion(LlamaChatRequest request) { + return this.internalInvocation(request, LlamaChatResponse.class); + } + + @Override + public Flux chatCompletionStream(LlamaChatRequest request) { + return this.internalInvocationStream(request, LlamaChatResponse.class); + } + + /** + * Llama models version. + */ + public enum LlamaChatModel implements ChatModelDescription { + + /** + * meta.llama2-13b-chat-v1 + */ + LLAMA2_13B_CHAT_V1("meta.llama2-13b-chat-v1"), + + /** + * meta.llama2-70b-chat-v1 + */ + LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"), + + /** + * meta.llama3-8b-instruct-v1:0 + */ + LLAMA3_8B_INSTRUCT_V1("meta.llama3-8b-instruct-v1:0"), + + /** + * meta.llama3-70b-instruct-v1:0 + */ + LLAMA3_70B_INSTRUCT_V1("meta.llama3-70b-instruct-v1:0"), + + /** + * meta.llama3-1-8b-instruct-v1:0 + */ + LLAMA3_1_8B_INSTRUCT_V1("meta.llama3-1-8b-instruct-v1:0"), + + /** + * meta.llama3-1-70b-instruct-v1:0 + */ + LLAMA3_1_70B_INSTRUCT_V1("meta.llama3-1-70b-instruct-v1:0"), + + /** + * meta.llama3-1-405b-instruct-v1:0 + */ + LLAMA3_1_405B_INSTRUCT_V1("meta.llama3-1-405b-instruct-v1:0"), + + /** + * meta.llama3-2-1b-instruct-v1:0 + */ + LLAMA3_2_1B_INSTRUCT_V1("meta.llama3-2-1b-instruct-v1:0"), + + /** + * meta.llama3-2-3b-instruct-v1:0 + */ + LLAMA3_2_3B_INSTRUCT_V1("meta.llama3-2-3b-instruct-v1:0"), + + /** + * meta.llama3-2-11b-instruct-v1:0 + */ + LLAMA3_2_11B_INSTRUCT_V1("meta.llama3-2-11b-instruct-v1:0"), + + /** + * meta.llama3-2-90b-instruct-v1:0 + */ + LLAMA3_2_90B_INSTRUCT_V1("meta.llama3-2-90b-instruct-v1:0"); + + private final String id; + + LlamaChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model. * @@ -162,10 +252,10 @@ public Builder withMaxGenLen(Integer maxGenLen) { public LlamaChatRequest build() { return new LlamaChatRequest( - prompt, - temperature, - topP, - maxGenLen + this.prompt, + this.temperature, + this.topP, + this.maxGenLen ); } } @@ -204,94 +294,5 @@ public enum StopReason { @JsonProperty("length") LENGTH } } - - /** - * Llama models version. - */ - public enum LlamaChatModel implements ChatModelDescription { - - /** - * meta.llama2-13b-chat-v1 - */ - LLAMA2_13B_CHAT_V1("meta.llama2-13b-chat-v1"), - - /** - * meta.llama2-70b-chat-v1 - */ - LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"), - - /** - * meta.llama3-8b-instruct-v1:0 - */ - LLAMA3_8B_INSTRUCT_V1("meta.llama3-8b-instruct-v1:0"), - - /** - * meta.llama3-70b-instruct-v1:0 - */ - LLAMA3_70B_INSTRUCT_V1("meta.llama3-70b-instruct-v1:0"), - - /** - * meta.llama3-1-8b-instruct-v1:0 - */ - LLAMA3_1_8B_INSTRUCT_V1("meta.llama3-1-8b-instruct-v1:0"), - - /** - * meta.llama3-1-70b-instruct-v1:0 - */ - LLAMA3_1_70B_INSTRUCT_V1("meta.llama3-1-70b-instruct-v1:0"), - - /** - * meta.llama3-1-405b-instruct-v1:0 - */ - LLAMA3_1_405B_INSTRUCT_V1("meta.llama3-1-405b-instruct-v1:0"), - - /** - * meta.llama3-2-1b-instruct-v1:0 - */ - LLAMA3_2_1B_INSTRUCT_V1("meta.llama3-2-1b-instruct-v1:0"), - - /** - * meta.llama3-2-3b-instruct-v1:0 - */ - LLAMA3_2_3B_INSTRUCT_V1("meta.llama3-2-3b-instruct-v1:0"), - - /** - * meta.llama3-2-11b-instruct-v1:0 - */ - LLAMA3_2_11B_INSTRUCT_V1("meta.llama3-2-11b-instruct-v1:0"), - - /** - * meta.llama3-2-90b-instruct-v1:0 - */ - LLAMA3_2_90B_INSTRUCT_V1("meta.llama3-2-90b-instruct-v1:0"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - LlamaChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public LlamaChatResponse chatCompletion(LlamaChatRequest request) { - return this.internalInvocation(request, LlamaChatResponse.class); - } - - @Override - public Flux chatCompletionStream(LlamaChatRequest request) { - return this.internalInvocationStream(request, LlamaChatResponse.class); - } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index e6d55a03bdb..1003ef0443b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.util.List; @@ -24,13 +25,13 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java index d1187f11894..a5d06bdba5c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.util.List; @@ -20,11 +21,10 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -59,39 +59,17 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private BedrockTitanChatOptions options = new BedrockTitanChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withMaxTokenCount(Integer maxTokenCount) { - this.options.maxTokenCount = maxTokenCount; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.stopSequences = stopSequences; - return this; - } - - public BedrockTitanChatOptions build() { - return this.options; - } - + public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withMaxTokenCount(fromOptions.getMaxTokenCount()) + .withStopSequences(fromOptions.getStopSequences()) + .build(); } @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -100,7 +78,7 @@ public void setTemperature(Double temperature) { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -119,7 +97,7 @@ public void setMaxTokens(Integer maxTokens) { } public Integer getMaxTokenCount() { - return maxTokenCount; + return this.maxTokenCount; } public void setMaxTokenCount(Integer maxTokenCount) { @@ -128,7 +106,7 @@ public void setMaxTokenCount(Integer maxTokenCount) { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -164,12 +142,34 @@ public BedrockTitanChatOptions copy() { return fromOptions(this); } - public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withMaxTokenCount(fromOptions.getMaxTokenCount()) - .withStopSequences(fromOptions.getStopSequences()) - .build(); + public static class Builder { + + private BedrockTitanChatOptions options = new BedrockTitanChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withMaxTokenCount(Integer maxTokenCount) { + this.options.maxTokenCount = maxTokenCount; + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.stopSequences = stopSequences; + return this; + } + + public BedrockTitanChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java index 84a646b8475..c07527b0eb4 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.util.ArrayList; @@ -50,12 +51,6 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { private final TitanEmbeddingBedrockApi embeddingApi; - public enum InputType { - - TEXT, IMAGE - - } - /** * Titan Embedding API input types. Could be either text or image (encoded in base64). */ @@ -83,7 +78,7 @@ public float[] embed(Document document) { public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); if (request.getInstructions().size() != 1) { - logger.warn( + this.logger.warn( "Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } @@ -113,7 +108,7 @@ private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, E public int dimensions() { if (this.inputType == InputType.IMAGE) { if (this.embeddingDimensions.get() < 0) { - this.embeddingDimensions.set(dimensions(this, embeddingApi.getModelId(), + this.embeddingDimensions.set(dimensions(this, this.embeddingApi.getModelId(), // small base64 encoded image "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")); } @@ -122,4 +117,10 @@ public int dimensions() { } + public enum InputType { + + TEXT, IMAGE + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java index 28757f3b78d..61b82dbdf97 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -39,23 +40,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions(); - - public Builder withInputType(InputType inputType) { - Assert.notNull(inputType, "input type can not be null."); - - this.options.setInputType(inputType); - return this; - } - - public BedrockTitanEmbeddingOptions build() { - return this.options; - } - - } - public InputType getInputType() { return this.inputType; } @@ -76,4 +60,21 @@ public Integer getDimensions() { return null; } + public static class Builder { + + private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions(); + + public Builder withInputType(InputType inputType) { + Assert.notNull(inputType, "input type can not be null."); + + this.options.setInputType(inputType); + return this; + } + + public BedrockTitanEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java index 85a1f10c7a5..19e76729de0 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.time.Duration; @@ -110,6 +111,55 @@ public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPro super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public TitanChatResponse chatCompletion(TitanChatRequest request) { + return this.internalInvocation(request, TitanChatResponse.class); + } + + @Override + public Flux chatCompletionStream(TitanChatRequest request) { + return this.internalInvocationStream(request, TitanChatResponseChunk.class); + } + + /** + * Titan models version. + */ + public enum TitanChatModel implements ChatModelDescription { + + /** + * amazon.titan-text-lite-v1 + */ + TITAN_TEXT_LITE_V1("amazon.titan-text-lite-v1"), + + /** + * amazon.titan-text-express-v1 + */ + TITAN_TEXT_EXPRESS_V1("amazon.titan-text-express-v1"), + + /** + * amazon.titan-text-premier-v1:0 + */ + TITAN_TEXT_PREMIER_V1("amazon.titan-text-premier-v1:0"); + + private final String id; + + TitanChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * TitanChatRequest encapsulates the request parameters for the Titan chat model. * @@ -121,6 +171,15 @@ public record TitanChatRequest( @JsonProperty("inputText") String inputText, @JsonProperty("textGenerationConfig") TextGenerationConfig textGenerationConfig) { + /** + * Create a new TitanChatRequest builder. + * @param inputText The prompt to use for the chat. + * @return A new TitanChatRequest builder. + */ + public static Builder builder(String inputText) { + return new Builder(inputText); + } + /** * Titan request text generation configuration. * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html @@ -141,15 +200,6 @@ public record TextGenerationConfig( @JsonProperty("stopSequences") List stopSequences) { } - /** - * Create a new TitanChatRequest builder. - * @param inputText The prompt to use for the chat. - * @return A new TitanChatRequest builder. - */ - public static Builder builder(String inputText) { - return new Builder(inputText); - } - public static class Builder { private final String inputText; private Double temperature; @@ -210,20 +260,6 @@ public record TitanChatResponse( @JsonProperty("inputTextTokenCount") Integer inputTextTokenCount, @JsonProperty("results") List results) { - /** - * Titan response result. - * - * @param tokenCount The number of tokens in the generated text. - * @param outputText The generated text. - * @param completionReason The reason the response finished being generated. - */ - @JsonInclude(Include.NON_NULL) - public record Result( - @JsonProperty("tokenCount") Integer tokenCount, - @JsonProperty("outputText") String outputText, - @JsonProperty("completionReason") CompletionReason completionReason) { - } - /** * The reason the response finished being generated. */ @@ -243,6 +279,20 @@ public enum CompletionReason { */ CONTENT_FILTERED } + + /** + * Titan response result. + * + * @param tokenCount The number of tokens in the generated text. + * @param outputText The generated text. + * @param completionReason The reason the response finished being generated. + */ + @JsonInclude(Include.NON_NULL) + public record Result( + @JsonProperty("tokenCount") Integer tokenCount, + @JsonProperty("outputText") String outputText, + @JsonProperty("completionReason") CompletionReason completionReason) { + } } /** @@ -263,54 +313,5 @@ public record TitanChatResponseChunk( @JsonProperty("completionReason") CompletionReason completionReason, @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { } - - /** - * Titan models version. - */ - public enum TitanChatModel implements ChatModelDescription { - - /** - * amazon.titan-text-lite-v1 - */ - TITAN_TEXT_LITE_V1("amazon.titan-text-lite-v1"), - - /** - * amazon.titan-text-express-v1 - */ - TITAN_TEXT_EXPRESS_V1("amazon.titan-text-express-v1"), - - /** - * amazon.titan-text-premier-v1:0 - */ - TITAN_TEXT_PREMIER_V1("amazon.titan-text-premier-v1:0"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - TitanChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public TitanChatResponse chatCompletion(TitanChatRequest request) { - return this.internalInvocation(request, TitanChatResponse.class); - } - - @Override - public Flux chatCompletionStream(TitanChatRequest request) { - return this.internalInvocationStream(request, TitanChatResponseChunk.class); - } } // @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index 01968c81cbc..b94ccff9a26 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.time.Duration; -import java.util.List; import java.util.Map; import com.fasterxml.jackson.annotation.JsonInclude; @@ -27,7 +27,6 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; import org.springframework.util.Assert; @@ -83,6 +82,42 @@ public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentia super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public TitanEmbeddingResponse embedding(TitanEmbeddingRequest request) { + return this.internalInvocation(request, TitanEmbeddingResponse.class); + } + + /** + * Titan Embedding model ids. + */ + public enum TitanEmbeddingModel { + /** + * amazon.titan-embed-image-v1 + */ + TITAN_EMBED_IMAGE_V1("amazon.titan-embed-image-v1"), + /** + * amazon.titan-embed-text-v1 + */ + TITAN_EMBED_TEXT_V1("amazon.titan-embed-text-v1"), + /** + * amazon.titan-embed-text-v2 + */ + TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; + + private final String id; + + TitanEmbeddingModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + } + /** * Titan Embedding request parameters. * @@ -143,44 +178,8 @@ public record TitanEmbeddingResponse( @JsonProperty("inputTextTokenCount") Integer inputTextTokenCount, @JsonProperty("embeddingsByType") Map embeddingsByType, @JsonProperty("message") Object message) { - - - } - - /** - * Titan Embedding model ids. - */ - public enum TitanEmbeddingModel { - /** - * amazon.titan-embed-image-v1 - */ - TITAN_EMBED_IMAGE_V1("amazon.titan-embed-image-v1"), - /** - * amazon.titan-embed-text-v1 - */ - TITAN_EMBED_TEXT_V1("amazon.titan-embed-text-v1"), - /** - * amazon.titan-embed-text-v2 - */ - TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - TitanEmbeddingModel(String value) { - this.id = value; - } - } - - @Override - public TitanEmbeddingResponse embedding(TitanEmbeddingRequest request) { - return this.internalInvocation(request, TitanEmbeddingResponse.class); } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java index e0893036b3f..0ccd002d73e 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.time.Duration; @@ -22,7 +23,6 @@ import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; - import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -33,11 +33,11 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -70,8 +70,8 @@ class BedrockAnthropicChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -98,12 +98,12 @@ void multipleStreamAttempts() { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -140,16 +140,13 @@ void mapOutputConvert() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Disabled @Test void beanOutputConverterRecords() { @@ -165,7 +162,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConvert.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -186,7 +183,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -202,6 +199,10 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java index 3cf8b344b74..37eaa0b5784 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.time.Duration; @@ -38,7 +39,7 @@ public class BedrockAnthropicCreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockAnthropicChatModel(anthropicChatApi, + var client = new BedrockAnthropicChatModel(this.anthropicChatApi, AnthropicChatOptions.builder() .withTemperature(66.6) .withTopK(66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java index 8f0efe45acf..3638104cb63 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic.api; import java.time.Duration; @@ -28,11 +29,13 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; -import static org.assertj.core.api.Assertions.assertThat;; +import static org.assertj.core.api.Assertions.assertThat; + +; /** * @author Christian Tzolov @@ -57,7 +60,7 @@ public void chatCompletion() { .withTopK(10) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); System.out.println(response.completion()); assertThat(response).isNotNull(); @@ -67,7 +70,7 @@ public void chatCompletion() { assertThat(response.stop()).isEqualTo("\n\nHuman:"); assertThat(response.amazonBedrockInvocationMetrics()).isNull(); - logger.info("" + response); + this.logger.info("" + response); } @Test @@ -81,7 +84,7 @@ public void chatCompletionStream() { .withStopSequences(List.of("\n\nHuman:")) .build(); - Flux responseStream = anthropicChatApi.chatCompletionStream(request); + Flux responseStream = this.anthropicChatApi.chatCompletionStream(request); List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java index 22644b42bd6..d2d906035f1 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; import java.io.IOException; @@ -32,18 +33,18 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -72,8 +73,8 @@ class BedrockAnthropic3ChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -100,12 +101,12 @@ void multipleStreamAttempts() { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -142,16 +143,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -166,7 +164,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -187,7 +185,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -211,7 +209,7 @@ void multiModalityTest() throws IOException { var userMessage = new UserMessage("Explain what do you see o this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket"); @@ -222,12 +220,16 @@ void stopSequencesWithEmptyContents() { Anthropic3ChatOptions chatOptions = new Anthropic3ChatOptions(); chatOptions.setStopSequences(List.of("Hello")); - var response = chatModel.call(new Prompt("hi", chatOptions)); + var response = this.chatModel.call(new Prompt("hi", chatOptions)); assertThat(response).isNotNull(); assertThat(response.getResults()).isEmpty(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java index 75551ca1cb7..bb76ae3d8e2 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; +import java.time.Duration; +import java.util.List; + import org.junit.jupiter.api.Test; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; import org.springframework.ai.chat.prompt.Prompt; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -37,7 +39,7 @@ public class BedrockAnthropic3CreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockAnthropic3ChatModel(anthropicChatApi, + var client = new BedrockAnthropic3ChatModel(this.anthropicChatApi, Anthropic3ChatOptions.builder() .withTemperature(66.6) .withTopK(66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java index 48b89af37f3..55b054889c1 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,27 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3.api; +import java.time.Duration; +import java.util.List; +import java.util.stream.Collectors; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; -import java.util.stream.Collectors; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; @@ -63,9 +65,9 @@ public void chatCompletion() { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - logger.info("" + response.content()); + this.logger.info("" + response.content()); assertThat(response).isNotNull(); assertThat(response.content().get(0).text()).isNotEmpty(); @@ -75,7 +77,7 @@ public void chatCompletion() { assertThat(response.usage().inputTokens()).isGreaterThan(10); assertThat(response.usage().outputTokens()).isGreaterThan(100); - logger.info("" + response); + this.logger.info("" + response); } @Test @@ -103,9 +105,9 @@ public void chatMultiCompletion() { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - logger.info("" + response.content()); + this.logger.info("" + response.content()); assertThat(response).isNotNull(); assertThat(response.content().get(0).text()).isNotEmpty(); assertThat(response.content().get(0).text()).contains("Blackbeard"); @@ -114,7 +116,7 @@ public void chatMultiCompletion() { assertThat(response.usage().inputTokens()).isGreaterThan(30); assertThat(response.usage().outputTokens()).isGreaterThan(200); - logger.info("" + response); + this.logger.info("" + response); } @Test @@ -129,7 +131,7 @@ public void chatCompletionStream() { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - Flux responseStream = anthropicChatApi + Flux responseStream = this.anthropicChatApi .chatCompletionStream(request); List responses = responseStream.collectList().block(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java index 92060ef5bf0..f3f33bbfc20 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.aot; +import java.util.Arrays; +import java.util.List; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; @@ -26,10 +32,6 @@ import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.List; -import java.util.Set; -import java.util.Arrays; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java index b6c0027da94..a1e3d7a3f7b 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.time.Duration; @@ -45,7 +46,7 @@ public class BedrockCohereChatCreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockCohereChatModel(chatApi, + var client = new BedrockCohereChatModel(this.chatApi, BedrockCohereChatOptions.builder() .withTemperature(66.6) .withTopK(66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java index 5da9f8670d3..340b541b0f7 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.time.Duration; @@ -30,11 +31,11 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -65,8 +66,8 @@ class BedrockCohereChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -95,10 +96,10 @@ void roleTest() { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -134,16 +135,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -157,7 +155,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -178,7 +176,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -194,6 +192,10 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java index 194b657ed79..03d3f0145c7 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.time.Duration; @@ -46,17 +47,17 @@ class BedrockCohereEmbeddingModelIT { @Test void singleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -64,13 +65,13 @@ void batchEmbedding() { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void embeddingWthOptions() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), BedrockCohereEmbeddingOptions.builder().withInputType(InputType.SEARCH_DOCUMENT).build())); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -79,7 +80,7 @@ void embeddingWthOptions() { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @SpringBootConfiguration diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java index 287eec21fcb..27c11af673a 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere.api; import java.time.Duration; @@ -32,7 +33,9 @@ import org.springframework.ai.model.ModelOptionsUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy;; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +; /** * @author Christian Tzolov @@ -86,7 +89,7 @@ public void chatCompletion() { .withTruncate(Truncate.NONE) .build(); - CohereChatResponse response = cohereChatApi.chatCompletion(request); + CohereChatResponse response = this.cohereChatApi.chatCompletion(request); assertThat(response).isNotNull(); assertThat(response.prompt()).isEqualTo(request.prompt()); @@ -111,7 +114,7 @@ public void chatCompletionStream() { .withTruncate(Truncate.NONE) .build(); - Flux responseStream = cohereChatApi.chatCompletionStream(request); + Flux responseStream = this.cohereChatApi.chatCompletionStream(request); List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); @@ -132,7 +135,7 @@ public void testStreamConfigurations() { .withStream(true) .build(); - assertThatThrownBy(() -> cohereChatApi.chatCompletion(streamRequest)) + assertThatThrownBy(() -> this.cohereChatApi.chatCompletion(streamRequest)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("The request must be configured to return the complete response!"); @@ -141,7 +144,7 @@ public void testStreamConfigurations() { .withStream(false) .build(); - assertThatThrownBy(() -> cohereChatApi.chatCompletionStream(notStreamRequest)) + assertThatThrownBy(() -> this.cohereChatApi.chatCompletionStream(notStreamRequest)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("The request must be configured to stream the response!"); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java index 83afec90d62..e8154344bb4 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere.api; import java.time.Duration; @@ -49,7 +50,7 @@ public void embedText() { List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.NONE); - CohereEmbeddingResponse response = api.embedding(request); + CohereEmbeddingResponse response = this.api.embedding(request); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); @@ -64,7 +65,7 @@ public void embedTextWithTruncate() { List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.START); - CohereEmbeddingResponse response = api.embedding(request); + CohereEmbeddingResponse response = this.api.embedding(request); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); @@ -74,7 +75,7 @@ public void embedTextWithTruncate() { request = new CohereEmbeddingRequest(List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.END); - response = api.embedding(request); + response = this.api.embedding(request); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java index c7a6419772b..c0919cd03f4 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -29,10 +29,10 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -61,12 +61,12 @@ class BedrockAi21Jurassic2ChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -83,7 +83,7 @@ void testEmojiPenaltyFalse() { UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄 ?"); Prompt prompt = new Prompt(List.of(userMessage), options); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).matches(content -> content.contains("😄")); } @@ -98,12 +98,12 @@ void emojiPenaltyWhenTrueByDefaultApplyPenaltyTest() { .build(); UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄?"); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), options); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).doesNotContain("😄"); } @@ -120,7 +120,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -131,12 +131,12 @@ void mapOutputConverter() { @Test void simpleChatResponse() { UserMessage userMessage = new UserMessage("Tell me a joke about AI."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("AI"); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java index f3dedde9ef9..aa16faa7141 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.jurassic2.api; import java.time.Duration; @@ -20,7 +21,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -50,7 +50,7 @@ public void chatCompletion() { new Ai21Jurassic2ChatRequest.FloatScalePenalty(0.5f, true, true, true, true, true), new Ai21Jurassic2ChatRequest.IntegerScalePenalty(1, true, true, true, true, true)); - Ai21Jurassic2ChatResponse response = api.chatCompletion(request); + Ai21Jurassic2ChatResponse response = this.api.chatCompletion(request); assertThat(response).isNotNull(); assertThat(response.completions()).isNotEmpty(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java index c9239875eb6..168250f9d41 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; import java.time.Duration; @@ -30,11 +31,11 @@ import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -65,8 +66,8 @@ class BedrockLlamaChatModelIT { @Test void multipleStreamAttempts() { - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a Toy joke?"))); - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a Toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -93,12 +94,12 @@ void multipleStreamAttempts() { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -134,16 +135,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -158,7 +156,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -179,7 +177,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -195,6 +193,10 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java index 3add11d14a6..48c81556ba4 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; +import java.time.Duration; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -26,8 +28,6 @@ import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; import org.springframework.ai.chat.prompt.Prompt; -import java.time.Duration; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -45,7 +45,7 @@ public class BedrockLlamaCreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockLlamaChatModel(api, + var client = new BedrockLlamaChatModel(this.api, BedrockLlamaChatOptions.builder().withTemperature(66.6).withMaxGenLen(666).withTopP(0.66).build()); var request = client.createRequest(new Prompt("Test message content")); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java index 48844670c0b..664e021944e 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,23 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama.api; import java.time.Duration; import java.util.List; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -53,7 +53,7 @@ public void chatCompletion() { .withMaxGenLen(20) .build(); - LlamaChatResponse response = llamaChatApi.chatCompletion(request); + LlamaChatResponse response = this.llamaChatApi.chatCompletion(request); System.out.println(response.generation()); assertThat(response).isNotNull(); @@ -68,7 +68,7 @@ public void chatCompletion() { public void chatCompletionStream() { LlamaChatRequest request = new LlamaChatRequest("Hello, my name is", 0.9, 0.9, 20); - Flux responseStream = llamaChatApi.chatCompletionStream(request); + Flux responseStream = this.llamaChatApi.chatCompletionStream(request); List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java index 90705ecc251..81c62d70b41 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.time.Duration; @@ -41,7 +42,7 @@ public class BedrockTitanChatModelCreateRequestTests { @Test public void createRequestWithChatOptions() { - var model = new BedrockTitanChatModel(api, + var model = new BedrockTitanChatModel(this.api, BedrockTitanChatOptions.builder() .withTemperature(66.6) .withTopP(0.66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java index 45bc6fed1bb..b96991c38cb 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java @@ -31,11 +31,11 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java index 7670e8db9a2..ae4cdb6e3ff 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.io.IOException; @@ -20,9 +21,9 @@ import java.util.Base64; import java.util.List; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -37,8 +38,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import com.fasterxml.jackson.databind.ObjectMapper; - import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest @@ -51,12 +50,12 @@ class BedrockTitanEmbeddingModelIT { @Test void singleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test @@ -65,12 +64,12 @@ void imageEmbedding() throws IOException { byte[] image = new DefaultResourceLoader().getResource("classpath:/spring_framework.png") .getContentAsByteArray(); - EmbeddingResponse embeddingResponse = embeddingModel + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)), BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @SpringBootConfiguration diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java index 453e84490ef..094f182bb06 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.time.Duration; @@ -53,14 +54,14 @@ public class TitanChatBedrockApiIT { @Test public void chatCompletion() { - TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); + TitanChatResponse response = this.titanBedrockApi.chatCompletion(this.titanChatRequest); assertThat(response.results()).hasSize(1); assertThat(response.results().get(0).outputText()).contains("Blackbeard"); } @Test public void chatCompletionStream() { - Flux response = titanBedrockApi.chatCompletionStream(titanChatRequest); + Flux response = this.titanBedrockApi.chatCompletionStream(this.titanChatRequest); List results = response.collectList().block(); assertThat(results.stream() diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java index 4f7813b2d76..f27a56bf61a 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.io.IOException; import java.time.Duration; import java.util.Base64; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -30,8 +31,6 @@ import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; import org.springframework.core.io.DefaultResourceLoader; -import com.fasterxml.jackson.databind.ObjectMapper; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-huggingface/pom.xml b/models/spring-ai-huggingface/pom.xml index 4dac3c3ef78..9c74fb326b0 100644 --- a/models/spring-ai-huggingface/pom.xml +++ b/models/spring-ai-huggingface/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index 8a8c8d92b96..affd1a32923 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.huggingface; import java.util.ArrayList; @@ -25,15 +26,15 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.api.TextGenerationInferenceApi; import org.springframework.ai.huggingface.invoker.ApiClient; import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails; import org.springframework.ai.huggingface.model.GenerateParameters; import org.springframework.ai.huggingface.model.GenerateRequest; import org.springframework.ai.huggingface.model.GenerateResponse; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; -import org.springframework.ai.chat.prompt.Prompt; /** * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference @@ -92,14 +93,15 @@ public ChatResponse call(Prompt prompt) { generateRequest.setInputs(prompt.getContents()); GenerateParameters generateParameters = new GenerateParameters(); // TODO - need to expose API to set parameters per call. - generateParameters.setMaxNewTokens(maxNewTokens); + generateParameters.setMaxNewTokens(this.maxNewTokens); generateRequest.setParameters(generateParameters); GenerateResponse generateResponse = this.textGenApi.generate(generateRequest); String generatedText = generateResponse.getGeneratedText(); List generations = new ArrayList<>(); AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails(); - Map detailsMap = objectMapper.convertValue(allOfGenerateResponseDetails, + Map detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails, new TypeReference>() { + }); Generation generation = new Generation(generatedText, detailsMap); generations.add(generation); @@ -111,7 +113,7 @@ public ChatResponse call(Prompt prompt) { * @return The maximum number of new tokens. */ public int getMaxNewTokens() { - return maxNewTokens; + return this.maxNewTokens; } /** diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java index 8e2a90d9da3..5f933a09c8c 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.huggingface; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java index 48f32b6ba47..9106ae98d37 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.huggingface.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.huggingface.client; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.HuggingfaceChatModel; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_CHAT_URL", matches = ".+") @@ -44,7 +46,7 @@ void helloWorldCompletion() { [/INST] """; Prompt prompt = new Prompt(mistral7bInstruct); - ChatResponse chatResponse = huggingfaceChatModel.call(prompt); + ChatResponse chatResponse = this.huggingfaceChatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); String expectedResponse = """ { diff --git a/models/spring-ai-minimax/pom.xml b/models/spring-ai-minimax/pom.xml index ee8ea6c0743..85824eb4996 100644 --- a/models/spring-ai-minimax/pom.xml +++ b/models/spring-ai-minimax/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index e8aca8bdb19..7d677db41e7 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -61,15 +72,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import static org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE; @@ -90,14 +92,14 @@ public class MiniMaxChatModel extends AbstractToolCallSupport implements ChatMod private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); /** - * The default options used for the chat completion requests. + * The retry template used to retry the MiniMax API calls. */ - private final MiniMaxChatOptions defaultOptions; + public final RetryTemplate retryTemplate; /** - * The retry template used to retry the MiniMax API calls. + * The default options used for the chat completion requests. */ - public final RetryTemplate retryTemplate; + private final MiniMaxChatOptions defaultOptions; /** * Low-level access to the MiniMax API. @@ -174,6 +176,40 @@ public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options, this.observationRegistry = observationRegistry; } + private static Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + // the MiniMax's stream function calls response are really odd + // occasionally, tool call might get split. + // for example, id empty means the previous tool call is not finished, + // the toolCalls: + // [{id:'1',function:{name:'a'}},{id:'',function:{arguments:'[1]'}}] + // these need to be merged into [{id:'1', name:'a', arguments:'[1]'}] + // it worked before, maybe the model provider made some adjustments + .reduce(new ArrayList<>(), (acc, current) -> { + if (!acc.isEmpty() && current.id().isEmpty()) { + AssistantMessage.ToolCall prev = acc.get(acc.size() - 1); + acc.set(acc.size() - 1, new AssistantMessage.ToolCall(prev.id(), prev.type(), prev.name(), + current.function().arguments())); + } + else { + AssistantMessage.ToolCall currentToolCall = new AssistantMessage.ToolCall(current.id(), + current.type(), current.function().name(), current.function().arguments()); + acc.add(currentToolCall); + } + return acc; + }, (acc1, acc2) -> { + acc1.addAll(acc2); + return acc1; + }); + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } + @Override public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); @@ -376,40 +412,6 @@ private Generation buildGeneration(ChatCompletionMessage message, ChatCompletion return new Generation(assistantMessage, generationMetadata); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() - .stream() - // the MiniMax's stream function calls response are really odd - // occasionally, tool call might get split. - // for example, id empty means the previous tool call is not finished, - // the toolCalls: - // [{id:'1',function:{name:'a'}},{id:'',function:{arguments:'[1]'}}] - // these need to be merged into [{id:'1', name:'a', arguments:'[1]'}] - // it worked before, maybe the model provider made some adjustments - .reduce(new ArrayList<>(), (acc, current) -> { - if (!acc.isEmpty() && current.id().isEmpty()) { - AssistantMessage.ToolCall prev = acc.get(acc.size() - 1); - acc.set(acc.size() - 1, new AssistantMessage.ToolCall(prev.id(), prev.type(), prev.name(), - current.function().arguments())); - } - else { - AssistantMessage.ToolCall currentToolCall = new AssistantMessage.ToolCall(current.id(), - current.type(), current.function().name(), current.function().arguments()); - acc.add(currentToolCall); - } - return acc; - }, (acc1, acc2) -> { - acc1.addAll(acc2); - return acc1; - }); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); - return new Generation(assistantMessage, generationMetadata); - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 7762b4a4fe6..0e10ca20e7f 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.model.function.FunctionCallback; @@ -26,12 +34,6 @@ import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - /** * MiniMaxChatOptions represents the options for performing chat completion using the * MiniMax API. It provides methods to set and retrieve various options like model, @@ -157,119 +159,25 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected MiniMaxChatOptions options; - - public Builder() { - this.options = new MiniMaxChatOptions(); - } - - public Builder(MiniMaxChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withN(Integer n) { - this.options.n = n; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withResponseFormat(MiniMaxApi.ChatCompletionRequest.ResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withSeed(Integer seed) { - this.options.seed = seed; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withMaskSensitiveInfo(Boolean maskSensitiveInfo) { - this.options.maskSensitiveInfo = maskSensitiveInfo; - return this; - } - - public Builder withTools(List tools) { - this.options.tools = tools; - return this; - } - - public Builder withToolChoice(String toolChoice) { - this.options.toolChoice = toolChoice; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public MiniMaxChatOptions build() { - return this.options; - } - + public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { + return builder().withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withN(fromOptions.getN()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withMaskSensitiveInfo(fromOptions.getMaskSensitiveInfo()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override @@ -370,7 +278,7 @@ public void setTopP(Double topP) { } public Boolean getMaskSensitiveInfo() { - return maskSensitiveInfo; + return this.maskSensitiveInfo; } public void setMaskSensitiveInfo(Boolean maskSensitiveInfo) { @@ -405,7 +313,7 @@ public void setFunctionCallbacks(List functionCallbacks) { @Override public Set getFunctions() { - return functions; + return this.functions; } public void setFunctions(Set functionNames) { @@ -441,124 +349,157 @@ public void setToolContext(Map toolContext) { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((n == null) ? 0 : n.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((seed == null) ? 0 : seed.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((maskSensitiveInfo == null) ? 0 : maskSensitiveInfo.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.seed == null) ? 0 : this.seed.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.maskSensitiveInfo == null) ? 0 : this.maskSensitiveInfo.hashCode()); + result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); + result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } MiniMaxChatOptions other = (MiniMaxChatOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) + if (other.frequencyPenalty != null) { return false; + } } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { return false; + } if (this.maxTokens == null) { - if (other.maxTokens != null) + if (other.maxTokens != null) { return false; + } } - else if (!this.maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) { return false; + } if (this.n == null) { - if (other.n != null) + if (other.n != null) { return false; + } } - else if (!this.n.equals(other.n)) + else if (!this.n.equals(other.n)) { return false; + } if (this.presencePenalty == null) { - if (other.presencePenalty != null) + if (other.presencePenalty != null) { return false; + } } - else if (!this.presencePenalty.equals(other.presencePenalty)) + else if (!this.presencePenalty.equals(other.presencePenalty)) { return false; + } if (this.responseFormat == null) { - if (other.responseFormat != null) + if (other.responseFormat != null) { return false; + } } - else if (!this.responseFormat.equals(other.responseFormat)) + else if (!this.responseFormat.equals(other.responseFormat)) { return false; + } if (this.seed == null) { - if (other.seed != null) + if (other.seed != null) { return false; + } } - else if (!this.seed.equals(other.seed)) + else if (!this.seed.equals(other.seed)) { return false; + } if (this.stop == null) { - if (other.stop != null) + if (other.stop != null) { return false; + } } - else if (!stop.equals(other.stop)) + else if (!this.stop.equals(other.stop)) { return false; + } if (this.temperature == null) { - if (other.temperature != null) + if (other.temperature != null) { return false; + } } - else if (!this.temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) { return false; + } if (this.topP == null) { - if (other.topP != null) + if (other.topP != null) { return false; + } } - else if (!topP.equals(other.topP)) + else if (!this.topP.equals(other.topP)) { return false; + } if (this.maskSensitiveInfo == null) { - if (other.maskSensitiveInfo != null) + if (other.maskSensitiveInfo != null) { return false; + } } - else if (!maskSensitiveInfo.equals(other.maskSensitiveInfo)) + else if (!this.maskSensitiveInfo.equals(other.maskSensitiveInfo)) { return false; + } if (this.tools == null) { - if (other.tools != null) + if (other.tools != null) { return false; + } } - else if (!tools.equals(other.tools)) + else if (!this.tools.equals(other.tools)) { return false; + } if (this.toolChoice == null) { - if (other.toolChoice != null) + if (other.toolChoice != null) { return false; + } } - else if (!toolChoice.equals(other.toolChoice)) + else if (!this.toolChoice.equals(other.toolChoice)) { return false; + } if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) + if (other.proxyToolCalls != null) { return false; + } } - else if (!proxyToolCalls.equals(other.proxyToolCalls)) + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { return false; + } if (this.toolContext == null) { - if (other.toolContext != null) + if (other.toolContext != null) { return false; + } } - else if (!toolContext.equals(other.toolContext)) + else if (!this.toolContext.equals(other.toolContext)) { return false; + } return true; } @@ -568,25 +509,119 @@ public MiniMaxChatOptions copy() { return fromOptions(this); } - public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { - return builder().withModel(fromOptions.getModel()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withN(fromOptions.getN()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withSeed(fromOptions.getSeed()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withMaskSensitiveInfo(fromOptions.getMaskSensitiveInfo()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); + public static class Builder { + + protected MiniMaxChatOptions options; + + public Builder() { + this.options = new MiniMaxChatOptions(); + } + + public Builder(MiniMaxChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(MiniMaxApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withSeed(Integer seed) { + this.options.seed = seed; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withMaskSensitiveInfo(Boolean maskSensitiveInfo) { + this.options.maskSensitiveInfo = maskSensitiveInfo; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public MiniMaxChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java index 0ba752c389a..1882607e068 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.ArrayList; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -39,9 +44,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; - /** * MiniMax Embedding Model implementation. * diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java index d265e2dd687..9dffe18d127 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -42,6 +44,21 @@ public static Builder builder() { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected MiniMaxEmbeddingOptions options; @@ -61,19 +78,4 @@ public MiniMaxEmbeddingOptions build() { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java index 129eb2d76c7..01d7fb6206e 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.aot; import org.springframework.ai.minimax.api.MiniMaxApi; diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java index c254d07d6d0..3216f694056 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; import java.util.List; @@ -21,6 +22,13 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; @@ -35,14 +43,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonValue; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - // @formatter:off /** * Single class implementation of the MiniMax Chat Completion API and @@ -62,6 +62,8 @@ public class MiniMaxApi { private final WebClient webClient; + private final MiniMaxStreamFunctionCallingHelper chunkMerger = new MiniMaxStreamFunctionCallingHelper(); + /** * Create a new chat completion api with default base URL. * @@ -119,6 +121,99 @@ public MiniMaxApi(String baseUrl, String miniMaxToken, RestClient.Builder restCl .build(); } + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/text/chatcompletion_v2") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * + * @param chatRequest The chat completion request. Must have the stream property set to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v1/text/chatcompletion_v2") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + /** + * Creates an embedding vector representing the input text or token array. + * + * @param embeddingRequest The embedding request. + * @return Returns {@link EmbeddingList}. + * + */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.texts(), "The input can not be null."); + + Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); + + return this.restClient.post() + .uri("/v1/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + /** * MiniMax Chat Completion Models: * MiniMax Model. @@ -141,7 +236,7 @@ public enum ChatModel implements ChatModelDescription { } public String getValue() { - return value; + return this.value; } @Override @@ -150,6 +245,85 @@ public String getName() { } } + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") TOOL_CALL + } + + /** + * MiniMax Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 1536 + */ + Embo_01("embo-01"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + } + + /** + * MiniMax Embeddings Types + */ + public enum EmbeddingType { + + /** + * DB, used to generate vectors and store them in the library (as retrieved text) + */ + DB("db"), + + /** + * Query, used to generate vectors for queries (when used as retrieval text) + */ + Query("query"); + + @JsonValue + public final String value; + + EmbeddingType(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + } + /** * Represents a tool the model may call. Currently, only functions are supported as a tool. * @@ -382,6 +556,15 @@ public record ChatCompletionMessage( @JsonProperty("tool_call_id") String toolCallId, @JsonProperty("tool_calls") List toolCalls) { + /** + * Create a chat completion message with the given content and role. All other fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + /** * Get message content as String. */ @@ -395,15 +578,6 @@ public String content() { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); - } - /** * The role of the author of this message. */ @@ -441,22 +615,6 @@ public record MediaContent( @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl) { - /** - * @param url Either a URL of the image or the base64 encoded image data. - * The base64 encoded image data must have a special prefix in the following format: - * "data:{mimetype};base64,{base64-encoded-image-data}". - * @param detail Specifies the detail level of the image. - */ - @JsonInclude(Include.NON_NULL) - public record ImageUrl( - @JsonProperty("url") String url, - @JsonProperty("detail") String detail) { - - public ImageUrl(String url) { - this(url, null); - } - } - /** * Shortcut constructor for a text content. * @param text The text content of the message. @@ -472,6 +630,22 @@ public MediaContent(String text) { public MediaContent(ImageUrl imageUrl) { this("image_url", null, imageUrl); } + + /** + * @param url Either a URL of the image or the base64 encoded image data. + * The base64 encoded image data must have a special prefix in the following format: + * "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl( + @JsonProperty("url") String url, + @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + } } /** * The relevant tool call. @@ -501,43 +675,6 @@ public record ChatCompletionFunction( } } - public static String getTextContent(List content) { - return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") TOOL_CALL - } - /** * Represents a chat completion response returned by model, based on the provided input. * @@ -689,118 +826,6 @@ public record ChunkChoice( } } - /** - * Creates a model response for the given chat conversation. - * - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/text/chatcompletion_v2") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private final MiniMaxStreamFunctionCallingHelper chunkMerger = new MiniMaxStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * - * @param chatRequest The chat completion request. Must have the stream property set to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v1/text/chatcompletion_v2") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(monoChunk); - }) - .flatMap(mono -> mono); - } - - /** - * MiniMax Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * DIMENSION: 1536 - */ - Embo_01("embo-01"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - - /** - * MiniMax Embeddings Types - */ - public enum EmbeddingType { - - /** - * DB, used to generate vectors and store them in the library (as retrieved text) - */ - DB("db"), - - /** - * Query, used to generate vectors for queries (when used as retrieval text) - */ - Query("query"); - - @JsonValue - public final String value; - - EmbeddingType(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - /** * Creates an embedding vector representing the input text. * @@ -890,30 +915,5 @@ public record EmbeddingList( @JsonProperty("total_tokens") Integer totalTokens) { } - /** - * Creates an embedding vector representing the input text or token array. - * - * @param embeddingRequest The embedding request. - * @return Returns {@link EmbeddingList}. - * - */ - public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.texts(), "The input can not be null."); - - Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); - - return this.restClient.post() - .uri("/v1/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } // @formatter:on diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java index a8ed1b34a06..c83d1a4486b 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.minimax.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java index 82b2eca12b4..24a71ec0f48 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionFinishReason; @@ -26,9 +30,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.List; - /** * Helper class to support Streaming function calling. It can merge the streamed * ChatCompletionChunk in case of function calling message. diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java index a720be0e425..cb8a5a74a0b 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -26,10 +27,6 @@ */ public class MiniMaxUsage implements Usage { - public static MiniMaxUsage from(MiniMaxApi.Usage usage) { - return new MiniMaxUsage(usage); - } - private final MiniMaxApi.Usage usage; protected MiniMaxUsage(MiniMaxApi.Usage usage) { @@ -37,6 +34,10 @@ protected MiniMaxUsage(MiniMaxApi.Usage usage) { this.usage = usage; } + public static MiniMaxUsage from(MiniMaxApi.Usage usage) { + return new MiniMaxUsage(usage); + } + protected MiniMaxApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java index 221a6bccbd4..5c50ddf3095 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.minimax.api.MockWeatherService; import org.springframework.ai.model.function.FunctionCallbackWrapper; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java index 8a7914da9f2..0493fba621b 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java index 60812302da7..52cbd95ae2d 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.List; +import java.util.Objects; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage; @@ -24,10 +30,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest; import org.springframework.ai.minimax.api.MiniMaxApi.EmbeddingList; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; import static org.assertj.core.api.Assertions.assertThat; @@ -42,7 +44,7 @@ public class MiniMaxApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = miniMaxApi + ResponseEntity response = this.miniMaxApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-air", 0.7, false)); assertThat(response).isNotNull(); @@ -52,7 +54,7 @@ void chatCompletionEntity() { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = miniMaxApi + Flux response = this.miniMaxApi .chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-air", 0.7, true)); assertThat(response).isNotNull(); @@ -61,7 +63,8 @@ void chatCompletionStream() { @Test void embeddings() { - ResponseEntity response = miniMaxApi.embeddings(new MiniMaxApi.EmbeddingRequest("Hello world")); + ResponseEntity response = this.miniMaxApi + .embeddings(new MiniMaxApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); assertThat(Objects.requireNonNull(response.getBody()).vectors()).hasSize(1); diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java index d878a898d5b..fbbf900667d 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,17 @@ package org.springframework.ai.minimax.api; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role; @@ -31,10 +36,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -49,6 +50,15 @@ public class MiniMaxApiToolFunctionCallIT { MiniMaxApi miniMaxApi = new MiniMaxApi(System.getenv("MINIMAX_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -89,7 +99,7 @@ public void toolFunctionCall() { org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_Chat.getValue(), List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = miniMaxApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.miniMaxApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -108,7 +118,7 @@ public void toolFunctionCall() { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -119,9 +129,9 @@ public void toolFunctionCall() { var functionResponseRequest = new ChatCompletionRequest(messages, org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_Chat.getValue(), 0.5); - ResponseEntity chatCompletion2 = miniMaxApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.miniMaxApi.chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); @@ -146,7 +156,7 @@ public void webSearchToolFunctionCall() { org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = miniMaxApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.miniMaxApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -158,13 +168,4 @@ public void webSearchToolFunctionCall() { assertThat(assistantMessage.content()).contains("40"); } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java index b20bada56c4..d2099ef8457 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.minimax.MiniMaxChatModel; @@ -41,10 +47,6 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -58,25 +60,6 @@ @ExtendWith(MockitoExtension.class) public class MiniMaxRetryTests { - private class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -89,13 +72,14 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); - - chatModel = new MiniMaxChatModel(miniMaxApi, MiniMaxChatOptions.builder().build(), null, retryTemplate); - embeddingModel = new MiniMaxEmbeddingModel(miniMaxApi, MetadataMode.EMBED, - MiniMaxEmbeddingOptions.builder().build(), retryTemplate); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); + + this.chatModel = new MiniMaxChatModel(this.miniMaxApi, MiniMaxChatOptions.builder().build(), null, + this.retryTemplate); + this.embeddingModel = new MiniMaxEmbeddingModel(this.miniMaxApi, MetadataMode.EMBED, + MiniMaxEmbeddingOptions.builder().build(), this.retryTemplate); } @Test @@ -106,24 +90,24 @@ public void miniMaxChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, null, new MiniMaxApi.Usage(10, 10, 10)); - when(miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void miniMaxChatNonTransientError() { - when(miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -134,24 +118,24 @@ public void miniMaxChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, null); - when(miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void miniMaxChatStreamNonTransientError() { - when(miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } @Test @@ -159,25 +143,45 @@ public void miniMaxEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(new float[] { 9.9f, 8.8f }), "model", 10); - when(miniMaxApi.embeddings(isA(EmbeddingRequest.class))) + when(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void miniMaxEmbeddingNonTransientError() { - when(miniMaxApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + when(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java index d2f4a9e53d0..0d3b164524c 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Geng Rong */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } /** @@ -65,28 +71,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } -} \ No newline at end of file +} diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java index 486456fa257..bd18fe5a38c 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class MiniMaxChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -79,7 +81,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -102,7 +104,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -122,7 +124,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java index 23daaf5cbc6..024de6f262a 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java @@ -1,9 +1,32 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.minimax.chat; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -13,12 +36,6 @@ import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.minimax.api.MiniMaxApi; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat; @@ -42,7 +59,7 @@ void testMarkSensitiveInfo() { List messages = new ArrayList<>(List.of(userMessage)); // markSensitiveInfo is enabled by default - ChatResponse response = chatModel.call(new Prompt(messages)); + ChatResponse response = this.chatModel.call(new Prompt(messages)); String responseContent = response.getResult().getOutput().getContent(); assertThat(responseContent).contains("133-**"); @@ -50,7 +67,7 @@ void testMarkSensitiveInfo() { var chatOptions = MiniMaxChatOptions.builder().withMaskSensitiveInfo(false).build(); - ChatResponse unmaskResponse = chatModel.call(new Prompt(messages, chatOptions)); + ChatResponse unmaskResponse = this.chatModel.call(new Prompt(messages, chatOptions)); String unmaskResponseContent = unmaskResponse.getResult().getOutput().getContent(); assertThat(unmaskResponseContent).contains("133-12345678"); @@ -80,7 +97,7 @@ void testWebSearch() { .withTools(functionTool) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, options)); + ChatResponse response = this.chatModel.call(new Prompt(messages, options)); String responseContent = response.getResult().getOutput().getContent(); assertThat(responseContent).contains("40"); @@ -110,7 +127,7 @@ void testWebSearchStream() { .withTools(functionTool) .build(); - Flux response = chatModel.stream(new Prompt(messages, options)); + Flux response = this.chatModel.stream(new Prompt(messages, options)); String content = Objects.requireNonNull(response.collectList().block()) .stream() .map(ChatResponse::getResults) diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java index 551c9fef6b7..2ce7e934a4f 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.embedding; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.minimax.MiniMaxEmbeddingModel; import org.springframework.ai.minimax.MiniMaxTestConfiguration; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,27 +41,27 @@ class EmbeddingIT { @Test void defaultEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World", "HI")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1536); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java index 1c0a8bfb087..51336796c7e 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,8 +39,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -62,13 +64,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-mistral-ai/pom.xml b/models/spring-ai-mistral-ai/pom.xml index ea6ba209718..5cae1c54039 100644 --- a/models/spring-ai-mistral-ai/pom.xml +++ b/models/spring-ai-mistral-ai/pom.xml @@ -1,33 +1,50 @@ - - 4.0.0 - - org.springframework.ai - spring-ai - 1.0.0-SNAPSHOT - ../../pom.xml - - spring-ai-mistral-ai - jar - Spring AI Model - Mistral AI - Mistral AI models support - https://github.com/spring-projects/spring-ai + - - https://github.com/spring-projects/spring-ai - git://github.com/spring-projects/spring-ai.git - git@github.com:spring-projects/spring-ai.git - + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-mistral-ai + jar + Spring AI Model - Mistral AI + Mistral AI models support + https://github.com/spring-projects/spring-ai - + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + - - - org.springframework.ai - spring-ai-core - ${project.parent.version} - + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + org.springframework.ai @@ -35,24 +52,24 @@ ${project.parent.version} - - - org.springframework - spring-context-support - + + + org.springframework + spring-context-support + - - org.springframework.boot - spring-boot-starter-logging - + + org.springframework.boot + spring-boot-starter-logging + - - - org.springframework.ai - spring-ai-test - ${project.version} - test - + + + org.springframework.ai + spring-ai-test + ${project.version} + test + io.micrometer @@ -60,6 +77,6 @@ test - + diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index bad45cdd4ad..c699962f933 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.HashSet; @@ -26,13 +27,20 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -59,9 +67,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Represents a Mistral AI Chat Model. * @@ -74,10 +79,10 @@ */ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel { - private final Logger logger = LoggerFactory.getLogger(getClass()); - private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private final Logger logger = LoggerFactory.getLogger(getClass()); + /** * The default options used for the chat completion requests. */ @@ -140,6 +145,17 @@ public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions option this.observationRegistry = observationRegistry; } + public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { + Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); + MistralAiUsage usage = MistralAiUsage.from(result.usage()); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withModel(result.model()) + .withUsage(usage) + .withKeyValue("created", result.created()) + .build(); + } + @Override public ChatResponse call(Prompt prompt) { @@ -156,13 +172,13 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = retryTemplate + ResponseEntity completionEntity = this.retryTemplate .execute(ctx -> this.mistralAiApi.chatCompletionEntity(request)); ChatCompletion chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); + this.logger.warn("No chat completion returned for prompt: {}", prompt); return new ChatResponse(List.of()); } @@ -213,7 +229,7 @@ public Flux stream(Prompt prompt) { observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - Flux completionChunks = retryTemplate + Flux completionChunks = this.retryTemplate .execute(ctx -> this.mistralAiApi.chatCompletionStream(request)); // For chunked responses, only the first chunk contains the choice role. @@ -250,7 +266,7 @@ public Flux stream(Prompt prompt) { } } catch (Exception e) { - logger.error("Error processing chat completion", e); + this.logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } })); @@ -294,17 +310,6 @@ private Generation buildGeneration(Choice choice, Map metadata) return new Generation(assistantMessage, generationMetadata); } - public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { - Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); - MistralAiUsage usage = MistralAiUsage.from(result.usage()); - return ChatResponseMetadata.builder() - .withId(result.id()) - .withModel(result.model()) - .withUsage(usage) - .withKeyValue("created", result.created()) - .build(); - } - private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { List choices = chunk.choices() .stream() diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index f2d8185232f..d256265bd43 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.ArrayList; @@ -25,10 +26,11 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; -import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -148,101 +150,22 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final MistralAiChatOptions options = new MistralAiChatOptions(); - - public Builder withModel(String model) { - this.options.setModel(model); - return this; - } - - public Builder withModel(MistralAiApi.ChatModel chatModel) { - this.options.setModel(chatModel.getName()); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.setMaxTokens(maxTokens); - return this; - } - - public Builder withSafePrompt(Boolean safePrompt) { - this.options.setSafePrompt(safePrompt); - return this; - } - - public Builder withRandomSeed(Integer randomSeed) { - this.options.setRandomSeed(randomSeed); - return this; - } - - public Builder withStop(List stop) { - this.options.setStop(stop); - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withResponseFormat(ResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withTools(List tools) { - this.options.tools = tools; - return this; - } - - public Builder withToolChoice(ToolChoice toolChoice) { - this.options.toolChoice = toolChoice; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public MistralAiChatOptions build() { - return this.options; - } - + public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) { + return builder().withModel(fromOptions.getModel()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withSafePrompt(fromOptions.getSafePrompt()) + .withRandomSeed(fromOptions.getRandomSeed()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override @@ -306,22 +229,22 @@ public void setStop(List stop) { this.stop = stop; } - public void setTools(List tools) { - this.tools = tools; - } - public List getTools() { return this.tools; } - public void setToolChoice(ToolChoice toolChoice) { - this.toolChoice = toolChoice; + public void setTools(List tools) { + this.tools = tools; } public ToolChoice getToolChoice() { return this.toolChoice; } + public void setToolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + } + @Override public Double getTemperature() { return this.temperature; @@ -404,38 +327,23 @@ public MistralAiChatOptions copy() { return fromOptions(this); } - public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) { - return builder().withModel(fromOptions.getModel()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withSafePrompt(fromOptions.getSafePrompt()) - .withRandomSeed(fromOptions.getRandomSeed()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withStop(fromOptions.getStop()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); - } - @Override public int hashCode() { - return Objects.hash(model, temperature, topP, maxTokens, safePrompt, randomSeed, responseFormat, stop, tools, - toolChoice, functionCallbacks, functions, proxyToolCalls, toolContext); + return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, + this.responseFormat, this.stop, this.tools, this.toolChoice, this.functionCallbacks, this.functions, + this.proxyToolCalls, this.toolContext); } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; + } - if (obj == null || getClass() != obj.getClass()) + if (obj == null || getClass() != obj.getClass()) { return false; + } MistralAiChatOptions other = (MistralAiChatOptions) obj; @@ -451,4 +359,101 @@ public boolean equals(Object obj) { && Objects.equals(this.toolContext, other.toolContext); } + public static class Builder { + + private final MistralAiChatOptions options = new MistralAiChatOptions(); + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withModel(MistralAiApi.ChatModel chatModel) { + this.options.setModel(chatModel.getName()); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withSafePrompt(Boolean safePrompt) { + this.options.setSafePrompt(safePrompt); + return this; + } + + public Builder withRandomSeed(Integer randomSeed) { + this.options.setRandomSeed(randomSeed); + return this; + } + + public Builder withStop(List stop) { + this.options.setStop(stop); + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withResponseFormat(ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(ToolChoice toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public MistralAiChatOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 197882130cb..c3f0a13e412 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.List; @@ -23,7 +24,13 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java index 7abfa01fc81..6409b05ca84 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java index cd6bcfa40f7..6ad65d426c9 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.aot; import org.springframework.ai.mistralai.api.MistralAiApi; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index a6e156f4644..41277a0c07f 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api; import java.util.Arrays; @@ -26,12 +27,12 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.observation.conventions.AiProvider; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.core.ParameterizedTypeReference; @@ -62,15 +63,17 @@ */ public class MistralAiApi { - private static final String DEFAULT_BASE_URL = "https://api.mistral.ai"; - public static final String PROVIDER_NAME = AiProvider.MISTRAL_AI.value(); + private static final String DEFAULT_BASE_URL = "https://api.mistral.ai"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; - private WebClient webClient; + private final WebClient webClient; + + private final MistralAiStreamFunctionCallingHelper chunkMerger = new MistralAiStreamFunctionCallingHelper(); /** * Create a new client api with DEFAULT_BASE_URL @@ -112,6 +115,201 @@ public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder r this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); } + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or + * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single + * request, You can pass a {@link List} of {@link String} or {@link List} of + * {@link List} of tokens. For example: + * + *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + // The input must not an empty string, and any array must be 1024 dimensions or + // less. + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 1024, "The list must be 1024 dimensions or less"); + Assert.isTrue( + list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri("/v1/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v1/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(mono1); + }) + .flatMap(mono -> mono); + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + // @formatter:off + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("model_length") MODEL_LENGTH, + /** + * + */ + @JsonProperty("error") ERROR, + /** + * The model requested a tool call. + */ + @JsonProperty("tool_calls") TOOL_CALLS + // @formatter:on + + } + + /** + * List of well-known Mistral chat models. + * https://docs.mistral.ai/platform/endpoints/#mistral-ai-generative-models + * + *

+ * Mistral AI provides two types of models: open-weights models (Mistral 7B, Mixtral + * 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral + * Medium, Mistral Large, and Mistral Embeddings). + */ + public enum ChatModel implements ChatModelDescription { + + // @formatter:off + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B + TINY("open-mistral-7b"), + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MIXTRAL_7B + MIXTRAL("open-mixtral-8x7b"), + OPEN_MISTRAL_7B("open-mistral-7b"), + OPEN_MIXTRAL_7B("open-mixtral-8x7b"), + OPEN_MIXTRAL_22B("open-mixtral-8x22b"), + SMALL("mistral-small-latest"), + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Mistral is removing this model + MEDIUM("mistral-medium-latest"), + LARGE("mistral-large-latest"); + // @formatter:on + + private final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + + /** + * List of well-known Mistral embedding models. + * https://docs.mistral.ai/platform/endpoints/#mistral-ai-embedding-model + */ + public enum EmbeddingModel { + + // @formatter:off + EMBED("mistral-embed"); + // @formatter:on + + private final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + /** * Represents a tool the model may call. Currently, only functions are supported as a * tool. @@ -168,7 +366,9 @@ public record Function(@JsonProperty("description") String description, @JsonPro public Function(String description, String name, String jsonSchema) { this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); } + } + } /** @@ -218,26 +418,29 @@ public Embedding(Integer index, float[] embedding) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Embedding embedding1)) + } + if (!(o instanceof Embedding embedding1)) { return false; - return Objects.equals(index, embedding1.index) && Arrays.equals(embedding, embedding1.embedding) - && Objects.equals(object, embedding1.object); + } + return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) + && Objects.equals(this.object, embedding1.object); } @Override public int hashCode() { - int result = Objects.hash(index, object); - result = 31 * result + Arrays.hashCode(embedding); + int result = Objects.hash(this.index, this.object); + result = 31 * result + Arrays.hashCode(this.embedding); return result; } @Override public String toString() { - return "Embedding{" + "index=" + index + ", embedding=" + Arrays.toString(embedding) + ", object='" + object - + '\'' + '}'; + return "Embedding{" + "index=" + this.index + ", embedding=" + Arrays.toString(this.embedding) + + ", object='" + this.object + '\'' + '}'; } + } /** @@ -274,6 +477,7 @@ public EmbeddingRequest(T input, String model) { public EmbeddingRequest(T input) { this(input, EmbeddingModel.EMBED.getValue()); } + } /** @@ -295,46 +499,6 @@ public record EmbeddingList( // @formatter:on } - /** - * Creates an embedding vector representing the input text or token array. - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or - * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single - * request, You can pass a {@link List} of {@link String} or {@link List} of - * {@link List} of tokens. For example: - * - *

{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
- */ - public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple - // inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.input(), "The input can not be null."); - Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, - "The input must be either a String, or a List of Strings or List of List of integers."); - - // The input must not an empty string, and any array must be 1024 dimensions or - // less. - if (embeddingRequest.input() instanceof List list) { - Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); - Assert.isTrue(list.size() <= 1024, "The list must be 1024 dimensions or less"); - Assert.isTrue( - list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, - "The input must be either a String, or a List of Strings or list of list of integers."); - } - - return this.restClient.post() - .uri("/v1/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - /** * Creates a model request for chat conversation. * @@ -472,7 +636,9 @@ public enum ToolChoice { */ @JsonInclude(Include.NON_NULL) public record ResponseFormat(@JsonProperty("type") String type) { + } + } /** @@ -547,6 +713,7 @@ public enum Role { @JsonInclude(Include.NON_NULL) public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) { + } /** @@ -559,36 +726,8 @@ public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") Stri @JsonInclude(Include.NON_NULL) public record ChatCompletionFunction(@JsonProperty("name") String name, @JsonProperty("arguments") String arguments) { - } - } - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - - // @formatter:off - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("model_length") MODEL_LENGTH, - /** - * - */ - @JsonProperty("error") ERROR, - /** - * The model requested a tool call. - */ - @JsonProperty("tool_calls") TOOL_CALLS - // @formatter:on + } } @@ -632,6 +771,7 @@ public record Choice( @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } + } /** @@ -676,8 +816,11 @@ public record Content(@JsonProperty("token") String token, @JsonProperty("logpro @JsonInclude(Include.NON_NULL) public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob, @JsonProperty("bytes") List probBytes) { + } + } + } /** @@ -719,132 +862,7 @@ public record ChunkChoice( @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } - } - /** - * List of well-known Mistral chat models. - * https://docs.mistral.ai/platform/endpoints/#mistral-ai-generative-models - * - *

- * Mistral AI provides two types of models: open-weights models (Mistral 7B, Mixtral - * 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral - * Medium, Mistral Large, and Mistral Embeddings). - */ - public enum ChatModel implements ChatModelDescription { - - // @formatter:off - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B - TINY("open-mistral-7b"), - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MIXTRAL_7B - MIXTRAL("open-mixtral-8x7b"), - OPEN_MISTRAL_7B("open-mistral-7b"), - OPEN_MIXTRAL_7B("open-mixtral-8x7b"), - OPEN_MIXTRAL_22B("open-mixtral-8x22b"), - SMALL("mistral-small-latest"), - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Mistral is removing this model - MEDIUM("mistral-medium-latest"), - LARGE("mistral-large-latest"); - // @formatter:on - - private final String value; - - ChatModel(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } - - @Override - public String getName() { - return this.value; - } - - } - - /** - * List of well-known Mistral embedding models. - * https://docs.mistral.ai/platform/endpoints/#mistral-ai-embedding-model - */ - public enum EmbeddingModel { - - // @formatter:off - EMBED("mistral-embed"); - // @formatter:on - - private final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } - - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private MistralAiStreamFunctionCallingHelper chunkMerger = new MistralAiStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v1/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(mono1); - }) - .flatMap(mono -> mono); } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java index 774bd072934..c00249eead8 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api; import java.util.ArrayList; @@ -105,7 +106,7 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { String content = (current.content() != null ? current.content() - : "" + ((previous.content() != null) ? previous.content() : "")); + : (previous.content() != null) ? previous.content() : ""); Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null String name = (current.name() != null ? current.name() : previous.name()); @@ -198,4 +199,4 @@ public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatComplet } } -// --- \ No newline at end of file +// --- diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java index c89982349e9..dbcc9a9d49b 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.mistralai.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -13,10 +29,6 @@ */ public class MistralAiUsage implements Usage { - public static MistralAiUsage from(MistralAiApi.Usage usage) { - return new MistralAiUsage(usage); - } - private final MistralAiApi.Usage usage; protected MistralAiUsage(MistralAiApi.Usage usage) { @@ -24,6 +36,10 @@ protected MistralAiUsage(MistralAiApi.Usage usage) { this.usage = usage; } + public static MistralAiUsage from(MistralAiApi.Usage usage) { + return new MistralAiUsage(usage); + } + protected MistralAiApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 618ca25a633..26b329137f4 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.Arrays; @@ -27,11 +28,8 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -57,14 +55,11 @@ class MistralAiChatClientIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test void call() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() - .system(s -> s.text(systemTextResource) + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -81,8 +76,8 @@ void call() { void testMessageHistory() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() - .system(s -> s.text(systemTextResource) + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -92,7 +87,7 @@ void testMessageHistory() { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); // @formatter:off - response = ChatClient.create(chatModel).prompt() + response = ChatClient.create(this.chatModel).prompt() .messages(List.of(new UserMessage("Dummy"), response.getResult().getOutput())) .user("Repeat the last assistant message.") .call() @@ -107,7 +102,7 @@ void testMessageHistory() { @Test void listOutputConverterString() { // @formatter:off - List collection = ChatClient.create(chatModel).prompt() + List collection = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -122,7 +117,7 @@ void listOutputConverterString() { void listOutputConverterBean() { // @formatter:off - List actorsFilms = ChatClient.create(chatModel).prompt() + List actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") .call() .entity(new ParameterizedTypeReference>() { @@ -139,7 +134,7 @@ void customOutputConverter() { var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); // @formatter:off - List flavors = ChatClient.create(chatModel).prompt() + List flavors = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List 10 {subject}") .param("subject", "ice cream flavors")) .call() @@ -154,7 +149,7 @@ void customOutputConverter() { @Test void mapOutputConverter() { // @formatter:off - Map result = ChatClient.create(chatModel).prompt() + Map result = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("Provide me a List of {subject}") .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) .call() @@ -169,7 +164,7 @@ void mapOutputConverter() { void beanOutputConverter() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography for a random actor.") .call() .entity(ActorsFilms.class); @@ -183,7 +178,7 @@ void beanOutputConverter() { void beanOutputConverterRecords() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks.") .call() .entity(ActorsFilms.class); @@ -200,7 +195,7 @@ void beanStreamOutputConverterRecords() { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = ChatClient.create(chatModel) + Flux chatResponse = ChatClient.create(this.chatModel) .prompt() .user(u -> u .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() @@ -226,7 +221,7 @@ void beanStreamOutputConverterRecords() { void functionCallTest() { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build()) .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) @@ -245,7 +240,7 @@ void functionCallTest() { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) @@ -264,7 +259,7 @@ void defaultFunctionCallTest() { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) @@ -284,7 +279,7 @@ void streamFunctionCallTest() { void validateCallResponseMetadata() { String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -299,4 +294,8 @@ void validateCallResponseMetadata() { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilms(String actor, List movies) { + + } + } \ No newline at end of file diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java index 29ffdb75a41..d4efe06609b 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import org.junit.jupiter.api.Test; @@ -37,7 +38,7 @@ public class MistralAiChatCompletionRequestTest { @Test void chatCompletionDefaultRequestTest() { - var request = chatModel.createRequest(new Prompt("test content"), false); + var request = this.chatModel.createRequest(new Prompt("test content"), false); assertThat(request.messages()).hasSize(1); assertThat(request.topP()).isEqualTo(1); @@ -52,7 +53,7 @@ void chatCompletionRequestWithOptionsTest() { var options = MistralAiChatOptions.builder().withTemperature(0.5).withTopP(0.8).build(); - var request = chatModel.createRequest(new Prompt("test content", options), true); + var request = this.chatModel.createRequest(new Prompt("test content", options), true); assertThat(request.messages().size()).isEqualTo(1); assertThat(request.topP()).isEqualTo(0.8); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index 0ec22331deb..9e3b2823020 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.ArrayList; @@ -27,13 +28,13 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -66,9 +67,6 @@ class MistralAiChatModelIT { @Autowired protected StreamingChatModel streamingChatModel; - @Value("classpath:/prompts/system-message.st") - private Resource systemResource; - @Value("classpath:/prompts/eval/qa-evaluator-accurate-answer.st") protected Resource qaEvaluatorAccurateAnswerResource; @@ -81,16 +79,19 @@ class MistralAiChatModelIT { @Value("classpath:/prompts/eval/user-evaluator-message.st") protected Resource userEvaluatorResource; + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + @Test void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); // NOTE: Mistral expects the system message to be before the user message or will // fail with 400 error. Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -126,16 +127,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -148,7 +146,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -169,7 +167,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -202,7 +200,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -225,7 +223,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = streamingChatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -240,4 +238,8 @@ void streamFunctionCallTest() { assertThat(content).containsAnyOf("10.0", "10"); } + record ActorsFilmsRecord(String actor, List movies) { + + } + } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java index 4631ab807cd..a6a42311d55 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.common.KeyValue; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class MistralAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -76,7 +78,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -97,7 +99,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -118,7 +120,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java index de378c57a5a..b9c91cca8a9 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.List; @@ -35,21 +36,21 @@ class MistralAiEmbeddingIT { @Test void defaultEmbedding() { - assertThat(mistralAiEmbeddingModel).isNotNull(); - var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.mistralAiEmbeddingModel).isNotNull(); + var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); - assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); } @Test void embeddingTest() { - assertThat(mistralAiEmbeddingModel).isNotNull(); - var embeddingResponse = mistralAiEmbeddingModel.call(new EmbeddingRequest( + assertThat(this.mistralAiEmbeddingModel).isNotNull(); + var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest( List.of("Hello World", "World is big"), MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build())); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -58,7 +59,7 @@ void embeddingTest() { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9); - assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java index 55f634c0ca1..813c62bbcd1 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -33,8 +37,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -63,13 +65,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 28b29fcd378..1818d2e43c7 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.mistralai; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +package org.springframework.ai.mistralai; import java.util.List; import java.util.Optional; @@ -29,6 +25,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -49,7 +47,10 @@ import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -59,25 +60,6 @@ @ExtendWith(MockitoExtension.class) public class MistralAiRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -90,21 +72,21 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new MistralAiChatModel(mistralAiApi, + this.chatModel = new MistralAiChatModel(this.mistralAiApi, MistralAiChatOptions.builder() .withTemperature(0.7) .withTopP(1.0) .withSafePrompt(false) .withModel(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) .build(), - null, retryTemplate); - embeddingModel = new MistralAiEmbeddingModel(mistralAiApi, MetadataMode.EMBED, + null, this.retryTemplate); + this.embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MetadataMode.EMBED, MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(), - retryTemplate); + this.retryTemplate); } @Test @@ -112,27 +94,27 @@ public void mistralAiChatTransientError() { var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP, null); - ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", + ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model", List.of(choice), new MistralAiApi.Usage(10, 10, 10)); - when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void mistralAiChatNonTransientError() { - when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -141,28 +123,28 @@ public void mistralAiChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP, null); - ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L, "model", List.of(choice)); - when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test @Disabled("Currently stream() does not implement retry") public void mistralAiChatStreamNonTransientError() { - when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text"))); } @Test @@ -171,26 +153,45 @@ public void mistralAiEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new MistralAiApi.Usage(10, 10, 10)); - when(mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + when(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void mistralAiEmbeddingNonTransientError() { - when(mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + when(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java index b48373f9fcf..608eccca55a 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java index 4a6f594a95d..0c7c4dacf47 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.function.Function; @@ -28,14 +29,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -57,34 +65,29 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - - @Override - public Response apply(Request request) { - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java index ed698b9f4da..1e16b590cf4 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java index 523ac4df2d1..48f57e8b9e7 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api; import java.util.List; @@ -45,7 +46,7 @@ public class MistralAiApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -62,7 +63,7 @@ void chatCompletionEntityWithSystemMessage() { You should reply to the user's request with your name and also in the style of a pirate. """, Role.SYSTEM); - ResponseEntity response = mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( List.of(systemMessage, userMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -72,7 +73,7 @@ void chatCompletionEntityWithSystemMessage() { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = mistralAiApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.mistralAiApi.chatCompletionStream(new ChatCompletionRequest( List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, true)); assertThat(response).isNotNull(); @@ -81,7 +82,7 @@ void chatCompletionStream() { @Test void embeddings() { - ResponseEntity> response = mistralAiApi + ResponseEntity> response = this.mistralAiApi .embeddings(new MistralAiApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java index 4d23255a6b3..b4ea08ed0d4 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api.tool; import java.util.ArrayList; @@ -31,8 +32,8 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall; -import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool.Type; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.http.ResponseEntity; @@ -46,14 +47,23 @@ @Disabled public class MistralAiApiToolFunctionCallIT { + static final String MISTRAL_AI_CHAT_MODEL = MistralAiApi.ChatModel.LARGE.getValue(); + private final Logger logger = LoggerFactory.getLogger(MistralAiApiToolFunctionCallIT.class); MockWeatherService weatherService = new MockWeatherService(); - static final String MISTRAL_AI_CHAT_MODEL = MistralAiApi.ChatModel.LARGE.getValue(); - MistralAiApi completionApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @Test @SuppressWarnings("null") public void toolFunctionCall() throws JsonProcessingException { @@ -100,7 +110,7 @@ public void toolFunctionCall() throws JsonProcessingException { System.out .println(new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatCompletionRequest)); - ResponseEntity chatCompletion = completionApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.completionApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -123,7 +133,7 @@ public void toolFunctionCall() throws JsonProcessingException { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), @@ -133,10 +143,10 @@ public void toolFunctionCall() throws JsonProcessingException { var functionResponseRequest = new ChatCompletionRequest(messages, MISTRAL_AI_CHAT_MODEL, 0.8); - ResponseEntity chatCompletion2 = completionApi + ResponseEntity chatCompletion2 = this.completionApi .chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(chatCompletion2.getBody().choices()).isNotEmpty(); @@ -145,21 +155,10 @@ public void toolFunctionCall() throws JsonProcessingException { .containsAnyOf("30.0°C", "30°C"); assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Tokyo") .containsAnyOf("10.0°C", "10°C"); - ; assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Paris") .containsAnyOf("15.0°C", "15°C"); - ; } } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java index 1c7c0d4de17..c468dffbc9e 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api.tool; import java.util.function.Function; @@ -28,14 +29,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -57,34 +65,29 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - - @Override - public Response apply(Request request) { - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java index a7a3deffc8e..dc144870183 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api.tool; import java.util.ArrayList; @@ -37,7 +38,6 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool.Type; -// import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.http.ResponseEntity; import static org.assertj.core.api.Assertions.assertThat; @@ -55,46 +55,25 @@ @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") public class PaymentStatusFunctionCallingIT { - private final Logger logger = LoggerFactory.getLogger(PaymentStatusFunctionCallingIT.class); - // Assuming we have the following data public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); - record StatusDate(String status, String date) { - } - - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - public record Date(@JsonProperty(required = true, value = "date") String date) { - } + static Map> functions = Map.of("retrieve_payment_status", + new RetrievePaymentStatus(), "retrieve_payment_date", new RetrievePaymentDate()); - private static class RetrievePaymentStatus implements Function { + private final Logger logger = LoggerFactory.getLogger(PaymentStatusFunctionCallingIT.class); - @Override - public Status apply(Transaction paymentTransaction) { - return new Status(DATA.get(paymentTransaction.transactionId).status); + private static T jsonToObject(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); } - - } - - private static class RetrievePaymentDate implements Function { - - @Override - public Date apply(Transaction paymentTransaction) { - return new Date(DATA.get(paymentTransaction.transactionId).date); + catch (JsonProcessingException e) { + throw new RuntimeException(e); } - } - static Map> functions = Map.of("retrieve_payment_status", - new RetrievePaymentStatus(), "retrieve_payment_date", new RetrievePaymentDate()); - @Test @SuppressWarnings("null") public void toolFunctionCall() throws JsonProcessingException { @@ -157,19 +136,44 @@ public void toolFunctionCall() throws JsonProcessingException { .chatCompletionEntity(new ChatCompletionRequest(messages, MistralAiApi.ChatModel.LARGE.getValue())); var responseContent = response.getBody().choices().get(0).message().content(); - logger.info("Final response: " + responseContent); + this.logger.info("Final response: " + responseContent); assertThat(responseContent).containsIgnoringCase("T1001"); assertThat(responseContent).containsIgnoringCase("Paid"); } - private static T jsonToObject(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); + record StatusDate(String status, String date) { + + } + + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + + private static class RetrievePaymentStatus implements Function { + + @Override + public Status apply(Transaction paymentTransaction) { + return new Status(DATA.get(paymentTransaction.transactionId).status); } - catch (JsonProcessingException e) { - throw new RuntimeException(e); + + } + + private static class RetrievePaymentDate implements Function { + + @Override + public Date apply(Transaction paymentTransaction) { + return new Date(DATA.get(paymentTransaction.transactionId).date); } + } } diff --git a/models/spring-ai-moonshot/pom.xml b/models/spring-ai-moonshot/pom.xml index 6f7f92b5238..84a154850f7 100644 --- a/models/spring-ai-moonshot/pom.xml +++ b/models/spring-ai-moonshot/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index bc5c3d13d1f..2751a9d6498 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -60,14 +70,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; /** * @author Geng Rong @@ -158,6 +160,21 @@ public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options, this.observationRegistry = observationRegistry; } + private static Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); + + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } + @Override public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); @@ -305,21 +322,6 @@ private ChatResponseMetadata from(ChatCompletion result) { .build(); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() - .stream() - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", - toolCall.function().name(), toolCall.function().arguments())) - .toList(); - - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); - return new Generation(assistantMessage, generationMetadata); - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index b5dd8109795..e5bae8560ab 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -25,12 +33,6 @@ import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - /** * @author Geng Rong * @author Thomas Vitale @@ -145,6 +147,10 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Map toolContext; + public static Builder builder() { + return new Builder(); + } + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -157,122 +163,13 @@ public void setFunctionCallbacks(List functionCallbacks) { @Override public Set getFunctions() { - return functions; + return this.functions; } public void setFunctions(Set functionNames) { this.functions = functionNames; } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - protected MoonshotChatOptions options; - - public Builder() { - this.options = new MoonshotChatOptions(); - } - - public Builder(MoonshotChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withN(Integer n) { - this.options.n = n; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withUser(String user) { - this.options.user = user; - return this; - } - - public Builder withTools(List tools) { - this.options.tools = tools; - return this; - } - - public Builder withToolChoice(String toolChoice) { - this.options.toolChoice = toolChoice; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public MoonshotChatOptions build() { - return this.options; - } - - } - @Override public String getModel() { return this.model; @@ -411,93 +308,220 @@ public MoonshotChatOptions copy() { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((n == null) ? 0 : n.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((user == null) ? 0 : user.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } MoonshotChatOptions other = (MoonshotChatOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) + if (other.frequencyPenalty != null) { return false; + } } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { return false; + } if (this.maxTokens == null) { - if (other.maxTokens != null) + if (other.maxTokens != null) { return false; + } } - else if (!this.maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) { return false; + } if (this.n == null) { - if (other.n != null) + if (other.n != null) { return false; + } } - else if (!this.n.equals(other.n)) + else if (!this.n.equals(other.n)) { return false; + } if (this.presencePenalty == null) { - if (other.presencePenalty != null) + if (other.presencePenalty != null) { return false; + } } - else if (!this.presencePenalty.equals(other.presencePenalty)) + else if (!this.presencePenalty.equals(other.presencePenalty)) { return false; + } if (this.stop == null) { - if (other.stop != null) + if (other.stop != null) { return false; + } } - else if (!stop.equals(other.stop)) + else if (!this.stop.equals(other.stop)) { return false; + } if (this.temperature == null) { - if (other.temperature != null) + if (other.temperature != null) { return false; + } } - else if (!this.temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) { return false; + } if (this.topP == null) { - if (other.topP != null) + if (other.topP != null) { return false; + } } - else if (!topP.equals(other.topP)) + else if (!this.topP.equals(other.topP)) { return false; + } if (this.user == null) { return other.user == null; } - else if (!this.user.equals(other.user)) + else if (!this.user.equals(other.user)) { return false; + } if (this.proxyToolCalls == null) { return other.proxyToolCalls == null; } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { return false; + } if (this.toolContext == null) { return other.toolContext == null; } - else if (!this.toolContext.equals(other.toolContext)) + else if (!this.toolContext.equals(other.toolContext)) { return false; + } return true; } + public static class Builder { + + protected MoonshotChatOptions options; + + public Builder() { + this.options = new MoonshotChatOptions(); + } + + public Builder(MoonshotChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public MoonshotChatOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java index 0ae4fccfe9a..7f8a3a27bd9 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.aot; import org.springframework.ai.moonshot.api.MoonshotApi; diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java index 43050b0252e..f6eb1c476c7 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.function.Predicate; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; @@ -29,14 +39,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import java.util.function.Predicate; import static org.springframework.ai.moonshot.api.MoonshotConstants.DEFAULT_BASE_URL; @@ -102,6 +104,147 @@ public MoonshotApi(String baseUrl, String moonshotApiKey, RestClient.Builder res this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); } + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v1/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + // Detect is the chunk is part of a streaming function call. + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + // Group all chunks belonging to the same function call. + // Flux -> Flux> + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + // Reduce the inner Flux window into a single + // Mono, + // Flux> -> Flux> + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + // Flux> -> Flux + .flatMap(mono -> mono); + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") + STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") + LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") + CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") + TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") + FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") + TOOL_CALL + + } + + /** + * Moonshot Chat Completion Models: + * + *

    + *
  • MOONSHOT_V1_8K - moonshot-v1-8k
  • + *
  • MOONSHOT_V1_32K - moonshot-v1-32k
  • + *
  • MOONSHOT_V1_128K - moonshot-v1-128k
  • + *
+ */ + public enum ChatModel implements ChatModelDescription { + + // @formatter:off + MOONSHOT_V1_8K("moonshot-v1-8k"), + MOONSHOT_V1_32K("moonshot-v1-32k"), + MOONSHOT_V1_128K("moonshot-v1-128k"); + // @formatter:on + + private final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + /** * Usage statistics. * @@ -252,6 +395,7 @@ public static Object function(String functionName) { } } + } /** @@ -275,6 +419,16 @@ public record ChatCompletionMessage( // @formatter:on ) { + /** + * Create a chat completion message with the given content and role. All other + * fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + /** * Get message content as String. */ @@ -288,16 +442,6 @@ public String content() { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other - * fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); - } - /** * The role of the author of this message. NOTE: Moonshot expects the system * message to be before the user message or will fail with 400 error. @@ -340,6 +484,7 @@ public enum Role { @JsonInclude(Include.NON_NULL) public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) { + } /** @@ -352,44 +497,8 @@ public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") Stri @JsonInclude(Include.NON_NULL) public record ChatCompletionFunction(@JsonProperty("name") String name, @JsonProperty("arguments") String arguments) { - } - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") - STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") - LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") - CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") - TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") - FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") - TOOL_CALL + } } @@ -431,6 +540,7 @@ public record Choice( @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { // @formatter:on } + } /** @@ -471,39 +581,7 @@ public record ChunkChoice( @JsonProperty("usage") Usage usage // @formatter:on ) { - } - } - - /** - * Moonshot Chat Completion Models: - * - *
    - *
  • MOONSHOT_V1_8K - moonshot-v1-8k
  • - *
  • MOONSHOT_V1_32K - moonshot-v1-32k
  • - *
  • MOONSHOT_V1_128K - moonshot-v1-128k
  • - *
- */ - public enum ChatModel implements ChatModelDescription { - - // @formatter:off - MOONSHOT_V1_8K("moonshot-v1-8k"), - MOONSHOT_V1_32K("moonshot-v1-32k"), - MOONSHOT_V1_128K("moonshot-v1-128k"); - // @formatter:on - - private final String value; - ChatModel(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } - - @Override - public String getName() { - return this.value; } } @@ -564,76 +642,9 @@ public record Function(@JsonProperty("description") String description, @JsonPro public Function(String description, String name, String jsonSchema) { this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); } - } - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); - AtomicBoolean isInsideTool = new AtomicBoolean(false); + } - return this.webClient.post() - .uri("/v1/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - // cancels the flux stream after the "[DONE]" is received. - .takeUntil(SSE_DONE_PREDICATE) - // filters out the "[DONE]" message. - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - // Detect is the chunk is part of a streaming function call. - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - // Group all chunks belonging to the same function call. - // Flux -> Flux> - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - // Reduce the inner Flux window into a single - // Mono, - // Flux> -> Flux> - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(monoChunk); - }) - // Flux> -> Flux - .flatMap(mono -> mono); } } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java index c2aea6c055a..3d6bdd4b272 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java index 5afff821618..06f1dc7655d 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java @@ -1,5 +1,24 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moonshot.api; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionFinishReason; @@ -9,9 +28,6 @@ import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ToolCall; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.List; - /** * Helper class to support Streaming function calling. It can merge the streamed * ChatCompletionChunk in case of function calling message. diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java index 5d5fadb1780..3fb67358a61 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moonshot.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -11,15 +27,15 @@ public class MoonshotUsage implements Usage { private final MoonshotApi.Usage usage; - public static MoonshotUsage from(MoonshotApi.Usage usage) { - return new MoonshotUsage(usage); - } - protected MoonshotUsage(MoonshotApi.Usage usage) { Assert.notNull(usage, "Moonshot Usage must not be null"); this.usage = usage; } + public static MoonshotUsage from(MoonshotApi.Usage usage) { + return new MoonshotUsage(usage); + } + protected MoonshotApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java index 89751f8c0c7..89b4781226f 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.boot.test.context.SpringBootTest; @@ -34,7 +36,7 @@ public class MoonshotChatCompletionRequestTest { @Test void chatCompletionDefaultRequestTest() { - var request = chatModel.createRequest(new Prompt("test content"), false); + var request = this.chatModel.createRequest(new Prompt("test content"), false); assertThat(request.messages()).hasSize(1); assertThat(request.topP()).isEqualTo(1); @@ -46,7 +48,7 @@ void chatCompletionDefaultRequestTest() { @Test void chatCompletionRequestWithOptionsTest() { var options = MoonshotChatOptions.builder().withTemperature(0.5).withTopP(0.8).build(); - var request = chatModel.createRequest(new Prompt("test content", options), true); + var request = this.chatModel.createRequest(new Prompt("test content", options), true); assertThat(request.messages().size()).isEqualTo(1); assertThat(request.topP()).isEqualTo(0.8); diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java index 4e1df77fab8..e87a1122796 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; +import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion; @@ -35,10 +41,6 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -52,25 +54,6 @@ @ExtendWith(MockitoExtension.class) public class MoonshotRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private @Mock MoonshotApi moonshotApi; @@ -80,10 +63,10 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void beforeEach() { RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); - chatModel = new MoonshotChatModel(moonshotApi, + this.chatModel = new MoonshotChatModel(this.moonshotApi, MoonshotChatOptions.builder() .withTemperature(0.7) .withTopP(1.0) @@ -100,24 +83,24 @@ public void moonshotChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", List.of(choice), new MoonshotApi.Usage(10, 10, 10)); - when(moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void moonshotChatNonTransientError() { - when(moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -128,24 +111,43 @@ public void moonshotChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, "model", List.of(choice)); - when(moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void moonshotChatStreamNonTransientError() { - when(moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java index 11db99be985..60a91076983 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; import org.springframework.ai.moonshot.api.MoonshotApi; diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java index e6015951db6..60bb11f0828 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java index 6c3619fdb38..402409649ea 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Geng Rong */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } /** @@ -65,28 +71,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } -} \ No newline at end of file +} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiIT.java index 6c6166db16e..e7ca21d9250 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -40,7 +42,7 @@ public class MoonshotApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = moonshotApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.moonshotApi.chatCompletionEntity(new ChatCompletionRequest( List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -57,7 +59,7 @@ void chatCompletionEntityWithSystemMessage() { You should reply to the user's request with your name and also in the style of a pirate. """, Role.SYSTEM); - ResponseEntity response = moonshotApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.moonshotApi.chatCompletionEntity(new ChatCompletionRequest( List.of(systemMessage, userMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -67,7 +69,7 @@ void chatCompletionEntityWithSystemMessage() { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = moonshotApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.moonshotApi.chatCompletionStream(new ChatCompletionRequest( List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, true)); assertThat(response).isNotNull(); diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java index 7c3764afe57..fa2cc486346 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,17 @@ package org.springframework.ai.moonshot.api; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role; @@ -32,10 +37,6 @@ import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -44,12 +45,6 @@ @EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+") public class MoonshotApiToolFunctionCallIT { - private final Logger logger = LoggerFactory.getLogger(MoonshotApiToolFunctionCallIT.class); - - private final MockWeatherService weatherService = new MockWeatherService(); - - private final MoonshotApi moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY")); - private static final FunctionTool FUNCTION_TOOL = new FunctionTool(Type.FUNCTION, new FunctionTool.Function( "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ { @@ -76,6 +71,21 @@ public class MoonshotApiToolFunctionCallIT { } """)); + private final Logger logger = LoggerFactory.getLogger(MoonshotApiToolFunctionCallIT.class); + + private final MockWeatherService weatherService = new MockWeatherService(); + + private final MoonshotApi moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY")); + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -97,7 +107,7 @@ private void toolFunctionCall(String userMessage, String cityName) { ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), List.of(FUNCTION_TOOL), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = moonshotApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.moonshotApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -116,7 +126,7 @@ private void toolFunctionCall(String userMessage, String cityName) { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -127,9 +137,9 @@ private void toolFunctionCall(String userMessage, String cityName) { var functionResponseRequest = new ChatCompletionRequest(messages, MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.5); - ResponseEntity chatCompletion2 = moonshotApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.moonshotApi.chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); @@ -138,13 +148,4 @@ private void toolFunctionCall(String userMessage, String cityName) { .containsAnyOf("30"); } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java index d4436cbb7d7..c0ec33e70af 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; import java.util.List; @@ -30,7 +31,7 @@ public ActorsFilms() { } public String getActor() { - return actor; + return this.actor; } public void setActor(String actor) { @@ -38,7 +39,7 @@ public void setActor(String actor) { } public List getMovies() { - return movies; + return this.movies; } public void setMovies(List movies) { @@ -47,7 +48,7 @@ public void setMovies(List movies) { @Override public String toString() { - return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}'; } } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java index 91a4bc9b3c3..8fa54687b2f 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -33,12 +41,6 @@ import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -68,7 +70,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -92,7 +94,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -108,4 +110,4 @@ void streamFunctionCallTest() { assertThat(content).contains("30", "10", "15"); } -} \ No newline at end of file +} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java index f0b4b794451..83222c75d23 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -39,11 +46,6 @@ import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.Resource; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -68,10 +70,10 @@ public class MoonshotChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -114,7 +116,7 @@ void mapOutputConverter() { "numbers": [1, 2, 3] }""", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -133,15 +135,12 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -161,7 +160,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -184,7 +183,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -200,4 +199,8 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java index 9e4239a5ce8..f13a46f4ad1 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class MoonshotChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -79,7 +81,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -102,7 +104,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -123,7 +125,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-oci-genai/pom.xml b/models/spring-ai-oci-genai/pom.xml index b4ab01cfea2..64a474969fd 100644 --- a/models/spring-ai-oci-genai/pom.xml +++ b/models/spring-ai-oci-genai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java index 123e0705f36..e3658a226a7 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import java.util.ArrayList; @@ -28,6 +29,7 @@ import com.oracle.bmc.generativeaiinference.model.ServingMode; import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest; import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -83,7 +85,7 @@ public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions option @Override public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), options); + OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), this.options); List embedTextRequests = createRequests(request.getInstructions(), runtimeOptions); EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder() @@ -109,7 +111,7 @@ private EmbeddingResponse embedAllWithContext(List embedTextRe AtomicInteger index = new AtomicInteger(0); List embeddings = new ArrayList<>(); for (EmbedTextRequest embedTextRequest : embedTextRequests) { - EmbedTextResult embedTextResult = genAi.embedText(embedTextRequest).getEmbedTextResult(); + EmbedTextResult embedTextResult = this.genAi.embedText(embedTextRequest).getEmbedTextResult(); if (modelId == null) { modelId = embedTextResult.getModelId(); } diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java index e72f5359f5f..3f5641a6e23 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -40,40 +42,14 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final OCIEmbeddingOptions options = new OCIEmbeddingOptions(); - - public Builder withModel(String model) { - this.options.setModel(model); - return this; - } - - public Builder withCompartment(String compartment) { - this.options.setCompartment(compartment); - return this; - } - - public Builder withServingMode(String servingMode) { - this.options.setServingMode(servingMode); - return this; - } - - public Builder withTruncate(EmbedTextDetails.Truncate truncate) { - this.options.truncate = truncate; - return this; - } - - public OCIEmbeddingOptions build() { - return this.options; - } - - } - public String getModel() { return this.model; } + public void setModel(String model) { + this.model = model; + } + /** * Not used by OCI GenAI. * @return null @@ -83,12 +59,8 @@ public Integer getDimensions() { return null; } - public void setModel(String model) { - this.model = model; - } - public String getCompartment() { - return compartment; + return this.compartment; } public void setCompartment(String compartment) { @@ -96,7 +68,7 @@ public void setCompartment(String compartment) { } public String getServingMode() { - return servingMode; + return this.servingMode; } public void setServingMode(String servingMode) { @@ -104,11 +76,41 @@ public void setServingMode(String servingMode) { } public EmbedTextDetails.Truncate getTruncate() { - return truncate; + return this.truncate; } public void setTruncate(EmbedTextDetails.Truncate truncate) { this.truncate = truncate; } + public static class Builder { + + private final OCIEmbeddingOptions options = new OCIEmbeddingOptions(); + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withCompartment(String compartment) { + this.options.setCompartment(compartment); + return this; + } + + public Builder withServingMode(String servingMode) { + this.options.setServingMode(servingMode); + return this; + } + + public Builder withTruncate(EmbedTextDetails.Truncate truncate) { + this.options.truncate = truncate; + return this; + } + + public OCIEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java index 5124bd734ee..b1f6da89b35 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import java.io.IOException; diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java index 8d25240bf82..586fbfddeeb 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.oci.BaseEmbeddingModelTest.OCI_COMPARTMENT_ID_KEY; -@EnabledIfEnvironmentVariable(named = OCI_COMPARTMENT_ID_KEY, matches = ".+") +@EnabledIfEnvironmentVariable(named = org.springframework.ai.oci.BaseEmbeddingModelTest.OCI_COMPARTMENT_ID_KEY, + matches = ".+") public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { private final OCIEmbeddingModel embeddingModel = get(); @@ -35,13 +37,13 @@ public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { @Test void embed() { - float[] embedding = embeddingModel.embed(new Document("How many provinces are in Canada?")); + float[] embedding = this.embeddingModel.embed(new Document("How many provinces are in Canada?")); assertThat(embedding).hasSize(1024); } @Test void call() { - EmbeddingResponse response = embeddingModel.call(new EmbeddingRequest(content, null)); + EmbeddingResponse response = this.embeddingModel.call(new EmbeddingRequest(this.content, null)); assertThat(response).isNotNull(); assertThat(response.getResults()).hasSize(2); assertThat(response.getMetadata().getModel()).isEqualTo(EMBEDDING_MODEL_V2); @@ -49,8 +51,8 @@ void call() { @Test void callWithOptions() { - EmbeddingResponse response = embeddingModel - .call(new EmbeddingRequest(content, OCIEmbeddingOptions.builder().withModel(EMBEDDING_MODEL_V3).build())); + EmbeddingResponse response = this.embeddingModel.call(new EmbeddingRequest(this.content, + OCIEmbeddingOptions.builder().withModel(EMBEDDING_MODEL_V3).build())); assertThat(response).isNotNull(); assertThat(response.getResults()).hasSize(2); assertThat(response.getMetadata().getModel()).isEqualTo(EMBEDDING_MODEL_V3); diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml index 69f64f8c9c8..3b9a4428a7b 100644 --- a/models/spring-ai-ollama/pom.xml +++ b/models/spring-ai-ollama/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 93edabe8c95..c60523c00ee 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import java.util.Base64; @@ -24,13 +25,19 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -43,22 +50,20 @@ import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall; import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.ai.ollama.metadata.OllamaChatUsage; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - /** * {@link ChatModel} implementation for {@literal Ollama}. Ollama allows developers to run * large language models and generate embeddings locally. It supports open-source models @@ -96,7 +101,7 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; this.observationRegistry = observationRegistry; - this.modelManager = new OllamaModelManager(chatApi, modelManagementOptions); + this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); } @@ -104,6 +109,22 @@ public static Builder builder() { return new Builder(); } + public static ChatResponseMetadata from(OllamaApi.ChatResponse response) { + Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); + return ChatResponseMetadata.builder() + .withUsage(OllamaChatUsage.from(response)) + .withModel(response.model()) + .withKeyValue("created-at", response.createdAt()) + .withKeyValue("eval-duration", response.evalDuration()) + .withKeyValue("eval-count", response.evalCount()) + .withKeyValue("load-duration", response.loadDuration()) + .withKeyValue("eval-duration", response.promptEvalDuration()) + .withKeyValue("eval-count", response.promptEvalCount()) + .withKeyValue("total-duration", response.totalDuration()) + .withKeyValue("done", response.done()) + .build(); + } + @Override public ChatResponse call(Prompt prompt) { @@ -157,22 +178,6 @@ && isToolCall(response, Set.of("stop"))) { return response; } - public static ChatResponseMetadata from(OllamaApi.ChatResponse response) { - Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - return ChatResponseMetadata.builder() - .withUsage(OllamaChatUsage.from(response)) - .withModel(response.model()) - .withKeyValue("created-at", response.createdAt()) - .withKeyValue("eval-duration", response.evalDuration()) - .withKeyValue("eval-count", response.evalCount()) - .withKeyValue("load-duration", response.loadDuration()) - .withKeyValue("eval-duration", response.promptEvalDuration()) - .withKeyValue("eval-count", response.promptEvalCount()) - .withKeyValue("total-duration", response.totalDuration()) - .withKeyValue("done", response.done()) - .build(); - } - @Override public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { @@ -435,10 +440,10 @@ public Builder withModelManagementOptions(ModelManagementOptions modelManagement } public OllamaChatModel build() { - return new OllamaChatModel(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks, - observationRegistry, modelManagementOptions); + return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackContext, + this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions); } } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index 7034a9c035f..f44c9c6ea40 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import java.time.Duration; @@ -22,19 +23,27 @@ import java.util.regex.Pattern; import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage; import org.springframework.util.Assert; @@ -236,9 +245,10 @@ public Builder withModelManagementOptions(ModelManagementOptions modelManagement } public OllamaEmbeddingModel build() { - return new OllamaEmbeddingModel(ollamaApi, defaultOptions, observationRegistry, modelManagementOptions); + return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry, + this.modelManagementOptions); } } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java index 2d89a804eb0..bd8799c9b8b 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.aot; import org.springframework.ai.ollama.api.OllamaApi; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index acd9028d1b7..bbd32c5117b 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; import java.io.IOException; @@ -23,8 +24,14 @@ import java.util.Objects; import java.util.function.Consumer; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.boot.context.properties.bind.ConstructorBinding; @@ -39,13 +46,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Java Client for the Ollama API. https://ollama.ai * @@ -56,40 +56,20 @@ // @formatter:off public class OllamaApi { - private static final Log logger = LogFactory.getLog(OllamaApi.class); - - private static final String DEFAULT_BASE_URL = "http://localhost:11434"; - public static final String PROVIDER_NAME = AiProvider.OLLAMA.value(); public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null."; + private static final Log logger = LogFactory.getLog(OllamaApi.class); + + private static final String DEFAULT_BASE_URL = "http://localhost:11434"; + private final ResponseErrorHandler responseErrorHandler; private final RestClient restClient; private final WebClient webClient; - private static class OllamaResponseErrorHandler implements ResponseErrorHandler { - - @Override - public boolean hasError(ClientHttpResponse response) throws IOException { - return response.getStatusCode().isError(); - } - - @Override - public void handleError(ClientHttpResponse response) throws IOException { - if (response.getStatusCode().isError()) { - int statusCode = response.getStatusCode().value(); - String statusText = response.getStatusText(); - String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8); - logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message)); - throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message)); - } - } - - } - /** * Default constructor that uses the default localhost url. */ @@ -125,9 +105,223 @@ public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); } + /** + * Generate a completion for the given prompt. + * @param completionRequest Completion request. + * @return Completion response. + * @deprecated Use {@link #chat(ChatRequest)} instead. + */ + @Deprecated(since = "1.0.0-M2", forRemoval = true) + public GenerateResponse generate(GenerateRequest completionRequest) { + Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(completionRequest.stream() == false, "Stream mode must be disabled."); + + return this.restClient.post() + .uri("/api/generate") + .body(completionRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(GenerateResponse.class); + } + // -------------------------------------------------------------------------- // Generate & Streaming Generate // -------------------------------------------------------------------------- + + /** + * Generate a streaming completion for the given prompt. + * @param completionRequest Completion request. The request must set the stream + * property to true. + * @return Completion response as a {@link Flux} stream. + * @deprecated Use {@link #streamingChat(ChatRequest)} instead. + */ + @Deprecated(since = "1.0.0-M2", forRemoval = true) + public Flux generateStreaming(GenerateRequest completionRequest) { + Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(completionRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/api/generate") + .body(Mono.just(completionRequest), GenerateRequest.class) + .retrieve() + .bodyToFlux(GenerateResponse.class) + .handle((data, sink) -> { + if (logger.isTraceEnabled()) { + logger.trace(data); + } + sink.next(data); + }); + } + + /** + * Generate the next message in a chat with a provided model. + * This is a streaming endpoint (controlled by the 'stream' request property), so + * there will be a series of responses. The final response object will include + * statistics and additional data from the request. + * @param chatRequest Chat request. + * @return Chat response. + */ + public ChatResponse chat(ChatRequest chatRequest) { + Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled."); + + return this.restClient.post() + .uri("/api/chat") + .body(chatRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(ChatResponse.class); + } + + /** + * Streaming response for the chat completion request. + * @param chatRequest Chat request. The request must set the stream property to true. + * @return Chat response as a {@link Flux} stream. + */ + public Flux streamingChat(ChatRequest chatRequest) { + Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/api/chat") + .body(Mono.just(chatRequest), GenerateRequest.class) + .retrieve() + .bodyToFlux(ChatResponse.class) + .handle((data, sink) -> { + if (logger.isTraceEnabled()) { + logger.trace(data); + } + sink.next(data); + }); + } + + /** + * Generate embeddings from a model. + * @param embeddingsRequest Embedding request. + * @return Embeddings response. + */ + public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) { + Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR); + + return this.restClient.post() + .uri("/api/embed") + .body(embeddingsRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(EmbeddingsResponse.class); + } + + // -------------------------------------------------------------------------- + // Chat & Streaming Chat + // -------------------------------------------------------------------------- + + /** + * Generate embeddings from a model. + * @param embeddingRequest Embedding request. + * @return Embedding response. + * @deprecated Use {@link #embed(EmbeddingsRequest)} instead. + */ + @Deprecated(since = "1.0.0-M2", forRemoval = true) + public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) { + Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_ERROR); + + return this.restClient.post() + .uri("/api/embeddings") + .body(embeddingRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(EmbeddingResponse.class); + } + + /** + * List models that are available locally on the machine where Ollama is running. + */ + public ListModelResponse listModels() { + return this.restClient.get() + .uri("/api/tags") + .retrieve() + .onStatus(this.responseErrorHandler) + .body(ListModelResponse.class); + } + + /** + * Show information about a model available locally on the machine where Ollama is running. + */ + public ShowModelResponse showModel(ShowModelRequest showModelRequest) { + Assert.notNull(showModelRequest, "showModelRequest must not be null"); + return this.restClient.post() + .uri("/api/show") + .body(showModelRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(ShowModelResponse.class); + } + + /** + * Copy a model. Creates a model with another name from an existing model. + */ + public ResponseEntity copyModel(CopyModelRequest copyModelRequest) { + Assert.notNull(copyModelRequest, "copyModelRequest must not be null"); + return this.restClient.post() + .uri("/api/copy") + .body(copyModelRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .toBodilessEntity(); + } + + /** + * Delete a model and its data. + */ + public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) { + Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null"); + return this.restClient.method(HttpMethod.DELETE) + .uri("/api/delete") + .body(deleteModelRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .toBodilessEntity(); + } + + // -------------------------------------------------------------------------- + // Embeddings + // -------------------------------------------------------------------------- + + /** + * Download a model from the Ollama library. Cancelled pulls are resumed from where they left off, + * and multiple calls will share the same download progress. + */ + public Flux pullModel(PullModelRequest pullModelRequest) { + Assert.notNull(pullModelRequest, "pullModelRequest must not be null"); + Assert.isTrue(pullModelRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/api/pull") + .bodyValue(pullModelRequest) + .retrieve() + .bodyToFlux(ProgressResponse.class); + } + + private static class OllamaResponseErrorHandler implements ResponseErrorHandler { + + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + int statusCode = response.getStatusCode().value(); + String statusText = response.getStatusText(); + String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8); + logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message)); + throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message)); + } + } + + } + /** * The request object sent to the /generate endpoint. * @@ -197,8 +391,10 @@ public static Builder builder(String prompt) { public static class Builder { - private String model; private final String prompt; + + private String model; + private String format; private Map options; private String system; @@ -269,7 +465,7 @@ public Builder withKeepAlive(String keepAlive) { } public GenerateRequest build() { - return new GenerateRequest(model, prompt, format, options, system, template, context, stream, raw, images, keepAlive); + return new GenerateRequest(this.model, this.prompt, this.format, this.options, this.system, this.template, this.context, this.stream, this.raw, this.images, this.keepAlive); } } @@ -312,53 +508,6 @@ public record GenerateResponse( @JsonProperty("eval_duration") Duration evalDuration) { } - /** - * Generate a completion for the given prompt. - * @param completionRequest Completion request. - * @return Completion response. - * @deprecated Use {@link #chat(ChatRequest)} instead. - */ - @Deprecated(since = "1.0.0-M2", forRemoval = true) - public GenerateResponse generate(GenerateRequest completionRequest) { - Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(completionRequest.stream() == false, "Stream mode must be disabled."); - - return this.restClient.post() - .uri("/api/generate") - .body(completionRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(GenerateResponse.class); - } - - /** - * Generate a streaming completion for the given prompt. - * @param completionRequest Completion request. The request must set the stream - * property to true. - * @return Completion response as a {@link Flux} stream. - * @deprecated Use {@link #streamingChat(ChatRequest)} instead. - */ - @Deprecated(since = "1.0.0-M2", forRemoval = true) - public Flux generateStreaming(GenerateRequest completionRequest) { - Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(completionRequest.stream(), "Request must set the stream property to true."); - - return webClient.post() - .uri("/api/generate") - .body(Mono.just(completionRequest), GenerateRequest.class) - .retrieve() - .bodyToFlux(GenerateResponse.class) - .handle((data, sink) -> { - if (logger.isTraceEnabled()) { - logger.trace(data); - } - sink.next(data); - }); - } - - // -------------------------------------------------------------------------- - // Chat & Streaming Chat - // -------------------------------------------------------------------------- /** * Chat message object. * @@ -374,6 +523,10 @@ public record Message( @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) { + public static Builder builder(Role role) { + return new Builder(role); + } + /** * The role of the message in the conversation. */ @@ -420,10 +573,6 @@ public record ToolCallFunction( @JsonProperty("arguments") Map arguments) { } - public static Builder builder(Role role) { - return new Builder(role); - } - public static class Builder { private final Role role; @@ -451,7 +600,7 @@ public Builder withToolCalls(List toolCalls) { } public Message build() { - return new Message(role, content, images, toolCalls); + return new Message(this.role, this.content, this.images, this.toolCalls); } } @@ -486,6 +635,10 @@ public record ChatRequest( @JsonProperty("options") Map options ) { + public static Builder builder(String model) { + return new Builder(model); + } + /** * Represents a tool the model may call. Currently, only functions are supported as a tool. * @@ -543,10 +696,6 @@ public Function(String description, String name, String jsonSchema) { } } } - - public static Builder builder(String model) { - return new Builder(model); - } public static class Builder { @@ -602,11 +751,15 @@ public Builder withOptions(OllamaOptions options) { } public ChatRequest build() { - return new ChatRequest(model, messages, stream, format, keepAlive, tools, options); + return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options); } } } + // -------------------------------------------------------------------------- + // Models + // -------------------------------------------------------------------------- + /** * Ollama chat response object. * @@ -647,51 +800,6 @@ public record ChatResponse( ) { } - /** - * Generate the next message in a chat with a provided model. - * This is a streaming endpoint (controlled by the 'stream' request property), so - * there will be a series of responses. The final response object will include - * statistics and additional data from the request. - * @param chatRequest Chat request. - * @return Chat response. - */ - public ChatResponse chat(ChatRequest chatRequest) { - Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled."); - - return this.restClient.post() - .uri("/api/chat") - .body(chatRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(ChatResponse.class); - } - - /** - * Streaming response for the chat completion request. - * @param chatRequest Chat request. The request must set the stream property to true. - * @return Chat response as a {@link Flux} stream. - */ - public Flux streamingChat(ChatRequest chatRequest) { - Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - return webClient.post() - .uri("/api/chat") - .body(Mono.just(chatRequest), GenerateRequest.class) - .retrieve() - .bodyToFlux(ChatResponse.class) - .handle((data, sink) -> { - if (logger.isTraceEnabled()) { - logger.trace(data); - } - sink.next(data); - }); - } - - // -------------------------------------------------------------------------- - // Embeddings - // -------------------------------------------------------------------------- /** * Generate embeddings from a model. * @@ -718,7 +826,7 @@ public record EmbeddingsRequest( public EmbeddingsRequest(String model, String input) { this(model, List.of(input), null, null, null); } - } + } /** * Generate embeddings from a model. @@ -759,11 +867,10 @@ public record EmbeddingResponse( @JsonProperty("embedding") List embedding) { } - /** * The response object returned from the /embedding endpoint. * @param model The model used for generating the embeddings. - * @param embeddings The list of embeddings generated from the model. + * @param embeddings The list of embeddings generated from the model. * Each embedding (list of doubles) corresponds to a single input text. */ @JsonInclude(Include.NON_NULL) @@ -773,45 +880,8 @@ public record EmbeddingsResponse( @JsonProperty("total_duration") Long totalDuration, @JsonProperty("load_duration") Long loadDuration, @JsonProperty("prompt_eval_count") Integer promptEvalCount) { - - } - /** - * Generate embeddings from a model. - * @param embeddingsRequest Embedding request. - * @return Embeddings response. - */ - public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) { - Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR); - - return this.restClient.post() - .uri("/api/embed") - .body(embeddingsRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(EmbeddingsResponse.class); } - /** - * Generate embeddings from a model. - * @param embeddingRequest Embedding request. - * @return Embedding response. - * @deprecated Use {@link #embed(EmbeddingsRequest)} instead. - */ - @Deprecated(since = "1.0.0-M2", forRemoval = true) - public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) { - Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_ERROR); - - return this.restClient.post() - .uri("/api/embeddings") - .body(embeddingRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(EmbeddingResponse.class); - } - - // -------------------------------------------------------------------------- - // Models - // -------------------------------------------------------------------------- @JsonInclude(Include.NON_NULL) public record Model( @@ -838,17 +908,6 @@ public record ListModelResponse( @JsonProperty("models") List models ) {} - /** - * List models that are available locally on the machine where Ollama is running. - */ - public ListModelResponse listModels() { - return this.restClient.get() - .uri("/api/tags") - .retrieve() - .onStatus(this.responseErrorHandler) - .body(ListModelResponse.class); - } - @JsonInclude(Include.NON_NULL) public record ShowModelRequest( @JsonProperty("model") String model, @@ -875,56 +934,17 @@ public record ShowModelResponse( @JsonProperty("modified_at") Instant modifiedAt ) {} - /** - * Show information about a model available locally on the machine where Ollama is running. - */ - public ShowModelResponse showModel(ShowModelRequest showModelRequest) { - Assert.notNull(showModelRequest, "showModelRequest must not be null"); - return this.restClient.post() - .uri("/api/show") - .body(showModelRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(ShowModelResponse.class); - } - @JsonInclude(Include.NON_NULL) public record CopyModelRequest( @JsonProperty("source") String source, @JsonProperty("destination") String destination ) {} - /** - * Copy a model. Creates a model with another name from an existing model. - */ - public ResponseEntity copyModel(CopyModelRequest copyModelRequest) { - Assert.notNull(copyModelRequest, "copyModelRequest must not be null"); - return this.restClient.post() - .uri("/api/copy") - .body(copyModelRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .toBodilessEntity(); - } - @JsonInclude(Include.NON_NULL) public record DeleteModelRequest( @JsonProperty("model") String model ) {} - /** - * Delete a model and its data. - */ - public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) { - Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null"); - return this.restClient.method(HttpMethod.DELETE) - .uri("/api/delete") - .body(deleteModelRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .toBodilessEntity(); - } - @JsonInclude(Include.NON_NULL) public record PullModelRequest( @JsonProperty("model") String model, @@ -953,20 +973,5 @@ public record ProgressResponse( @JsonProperty("completed") Long completed ) {} - /** - * Download a model from the Ollama library. Cancelled pulls are resumed from where they left off, - * and multiple calls will share the same download progress. - */ - public Flux pullModel(PullModelRequest pullModelRequest) { - Assert.notNull(pullModelRequest, "pullModelRequest must not be null"); - Assert.isTrue(pullModelRequest.stream(), "Request must set the stream property to true."); - - return this.webClient.post() - .uri("/api/pull") - .bodyValue(pullModelRequest) - .retrieve() - .bodyToFlux(ProgressResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index a70765249b6..3419ee28332 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; import org.springframework.ai.model.ChatModelDescription; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a0dad31a4b0..034a4b75c66 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; import java.util.ArrayList; @@ -23,6 +24,11 @@ import java.util.Set; import java.util.stream.Collectors; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.model.ModelOptionsUtils; @@ -31,11 +37,6 @@ import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * Helper class for creating strongly-typed Ollama options. * @@ -306,10 +307,70 @@ public static OllamaOptions builder() { return new OllamaOptions(); } + /** + * Helper factory method to create a new {@link OllamaOptions} instance. + * @return A new {@link OllamaOptions} instance. + */ + public static OllamaOptions create() { + return new OllamaOptions(); + } + + /** + * Filter out the non-supported fields from the options. + * @param options The options to filter. + * @return The filtered options. + */ + public static Map filterNonSupportedFields(Map options) { + return options.entrySet().stream() + .filter(e -> !NON_SUPPORTED_FIELDS.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static OllamaOptions fromOptions(OllamaOptions fromOptions) { + return new OllamaOptions() + .withModel(fromOptions.getModel()) + .withFormat(fromOptions.getFormat()) + .withKeepAlive(fromOptions.getKeepAlive()) + .withTruncate(fromOptions.getTruncate()) + .withUseNUMA(fromOptions.getUseNUMA()) + .withNumCtx(fromOptions.getNumCtx()) + .withNumBatch(fromOptions.getNumBatch()) + .withNumGPU(fromOptions.getNumGPU()) + .withMainGPU(fromOptions.getMainGPU()) + .withLowVRAM(fromOptions.getLowVRAM()) + .withF16KV(fromOptions.getF16KV()) + .withLogitsAll(fromOptions.getLogitsAll()) + .withVocabOnly(fromOptions.getVocabOnly()) + .withUseMMap(fromOptions.getUseMMap()) + .withUseMLock(fromOptions.getUseMLock()) + .withNumThread(fromOptions.getNumThread()) + .withNumKeep(fromOptions.getNumKeep()) + .withSeed(fromOptions.getSeed()) + .withNumPredict(fromOptions.getNumPredict()) + .withTopK(fromOptions.getTopK()) + .withTopP(fromOptions.getTopP()) + .withTfsZ(fromOptions.getTfsZ()) + .withTypicalP(fromOptions.getTypicalP()) + .withRepeatLastN(fromOptions.getRepeatLastN()) + .withTemperature(fromOptions.getTemperature()) + .withRepeatPenalty(fromOptions.getRepeatPenalty()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMirostat(fromOptions.getMirostat()) + .withMirostatTau(fromOptions.getMirostatTau()) + .withMirostatEta(fromOptions.getMirostatEta()) + .withPenalizeNewline(fromOptions.getPenalizeNewline()) + .withStop(fromOptions.getStop()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withToolContext(fromOptions.getToolContext()); + } + public OllamaOptions build() { return this; } - + /** * @param model The ollama model names to use. See the {@link OllamaModel} for the common models. */ @@ -510,7 +571,7 @@ public OllamaOptions withToolContext(Map toolContext) { } else { this.toolContext.putAll(toolContext); - } + } return this; } @@ -519,7 +580,7 @@ public OllamaOptions withToolContext(Map toolContext) { // ------------------- @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -811,7 +872,7 @@ public void setTruncate(Boolean truncate) { @Override public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.functionCallbacks; } @Override @@ -862,107 +923,51 @@ public Map toMap() { return ModelOptionsUtils.objectToMap(this); } - /** - * Helper factory method to create a new {@link OllamaOptions} instance. - * @return A new {@link OllamaOptions} instance. - */ - public static OllamaOptions create() { - return new OllamaOptions(); - } - - /** - * Filter out the non-supported fields from the options. - * @param options The options to filter. - * @return The filtered options. - */ - public static Map filterNonSupportedFields(Map options) { - return options.entrySet().stream() - .filter(e -> !NON_SUPPORTED_FIELDS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - @Override public OllamaOptions copy() { return fromOptions(this); } - - public static OllamaOptions fromOptions(OllamaOptions fromOptions) { - return new OllamaOptions() - .withModel(fromOptions.getModel()) - .withFormat(fromOptions.getFormat()) - .withKeepAlive(fromOptions.getKeepAlive()) - .withTruncate(fromOptions.getTruncate()) - .withUseNUMA(fromOptions.getUseNUMA()) - .withNumCtx(fromOptions.getNumCtx()) - .withNumBatch(fromOptions.getNumBatch()) - .withNumGPU(fromOptions.getNumGPU()) - .withMainGPU(fromOptions.getMainGPU()) - .withLowVRAM(fromOptions.getLowVRAM()) - .withF16KV(fromOptions.getF16KV()) - .withLogitsAll(fromOptions.getLogitsAll()) - .withVocabOnly(fromOptions.getVocabOnly()) - .withUseMMap(fromOptions.getUseMMap()) - .withUseMLock(fromOptions.getUseMLock()) - .withNumThread(fromOptions.getNumThread()) - .withNumKeep(fromOptions.getNumKeep()) - .withSeed(fromOptions.getSeed()) - .withNumPredict(fromOptions.getNumPredict()) - .withTopK(fromOptions.getTopK()) - .withTopP(fromOptions.getTopP()) - .withTfsZ(fromOptions.getTfsZ()) - .withTypicalP(fromOptions.getTypicalP()) - .withRepeatLastN(fromOptions.getRepeatLastN()) - .withTemperature(fromOptions.getTemperature()) - .withRepeatPenalty(fromOptions.getRepeatPenalty()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withMirostat(fromOptions.getMirostat()) - .withMirostatTau(fromOptions.getMirostatTau()) - .withMirostatEta(fromOptions.getMirostatEta()) - .withPenalizeNewline(fromOptions.getPenalizeNewline()) - .withStop(fromOptions.getStop()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withToolContext(fromOptions.getToolContext()); - } // @formatter:on @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } OllamaOptions that = (OllamaOptions) o; - return Objects.equals(model, that.model) && Objects.equals(format, that.format) - && Objects.equals(keepAlive, that.keepAlive) && Objects.equals(truncate, that.truncate) - && Objects.equals(useNUMA, that.useNUMA) && Objects.equals(numCtx, that.numCtx) - && Objects.equals(numBatch, that.numBatch) && Objects.equals(numGPU, that.numGPU) - && Objects.equals(mainGPU, that.mainGPU) && Objects.equals(lowVRAM, that.lowVRAM) - && Objects.equals(f16KV, that.f16KV) && Objects.equals(logitsAll, that.logitsAll) - && Objects.equals(vocabOnly, that.vocabOnly) && Objects.equals(useMMap, that.useMMap) - && Objects.equals(useMLock, that.useMLock) && Objects.equals(numThread, that.numThread) - && Objects.equals(numKeep, that.numKeep) && Objects.equals(seed, that.seed) - && Objects.equals(numPredict, that.numPredict) && Objects.equals(topK, that.topK) - && Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ) - && Objects.equals(typicalP, that.typicalP) && Objects.equals(repeatLastN, that.repeatLastN) - && Objects.equals(temperature, that.temperature) && Objects.equals(repeatPenalty, that.repeatPenalty) - && Objects.equals(presencePenalty, that.presencePenalty) - && Objects.equals(frequencyPenalty, that.frequencyPenalty) && Objects.equals(mirostat, that.mirostat) - && Objects.equals(mirostatTau, that.mirostatTau) && Objects.equals(mirostatEta, that.mirostatEta) - && Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop) - && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions) - && Objects.equals(toolContext, that.toolContext); + return Objects.equals(this.model, that.model) && Objects.equals(this.format, that.format) + && Objects.equals(this.keepAlive, that.keepAlive) && Objects.equals(this.truncate, that.truncate) + && Objects.equals(this.useNUMA, that.useNUMA) && Objects.equals(this.numCtx, that.numCtx) + && Objects.equals(this.numBatch, that.numBatch) && Objects.equals(this.numGPU, that.numGPU) + && Objects.equals(this.mainGPU, that.mainGPU) && Objects.equals(this.lowVRAM, that.lowVRAM) + && Objects.equals(this.f16KV, that.f16KV) && Objects.equals(this.logitsAll, that.logitsAll) + && Objects.equals(this.vocabOnly, that.vocabOnly) && Objects.equals(this.useMMap, that.useMMap) + && Objects.equals(this.useMLock, that.useMLock) && Objects.equals(this.numThread, that.numThread) + && Objects.equals(this.numKeep, that.numKeep) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.numPredict, that.numPredict) && Objects.equals(this.topK, that.topK) + && Objects.equals(this.topP, that.topP) && Objects.equals(this.tfsZ, that.tfsZ) + && Objects.equals(this.typicalP, that.typicalP) && Objects.equals(this.repeatLastN, that.repeatLastN) + && Objects.equals(this.temperature, that.temperature) + && Objects.equals(this.repeatPenalty, that.repeatPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.mirostatEta, that.mirostatEta) + && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.functionCallbacks, that.functionCallbacks) + && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) + && Objects.equals(this.functions, that.functions) && Objects.equals(this.toolContext, that.toolContext); } @Override public int hashCode() { return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, - this.numBatch, this.numGPU, this.mainGPU, lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, + this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, - this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, + this.topP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls, this.toolContext); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java index 5d600b14ed3..f850cd5793a 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.management; import java.time.Duration; @@ -66,7 +67,8 @@ public Builder withMaxRetries(Integer maxRetries) { } public ModelManagementOptions build() { - return new ModelManagementOptions(pullModelStrategy, additionalModels, timeout, maxRetries); + return new ModelManagementOptions(this.pullModelStrategy, this.additionalModels, this.timeout, + this.maxRetries); } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java index ebc736c8286..572ec5896c3 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java @@ -1,31 +1,33 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.ollama.management; +import java.time.Duration; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.util.retry.Retry; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest; import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse; import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.util.retry.Retry; - -import java.time.Duration; /** * Manage the lifecycle of models in Ollama. @@ -57,7 +59,7 @@ public OllamaModelManager(OllamaApi ollamaApi, ModelManagementOptions options) { public boolean isModelAvailable(String modelName) { Assert.hasText(modelName, "modelName must not be empty"); - ListModelResponse listModelResponse = ollamaApi.listModels(); + ListModelResponse listModelResponse = this.ollamaApi.listModels(); if (!CollectionUtils.isEmpty(listModelResponse.models())) { var normalizedModelName = normalizeModelName(modelName); return listModelResponse.models().stream().anyMatch(m -> m.name().equals(normalizedModelName)); @@ -79,17 +81,17 @@ private String normalizeModelName(String modelName) { } public void deleteModel(String modelName) { - logger.info("Start deletion of model: {}", modelName); + this.logger.info("Start deletion of model: {}", modelName); if (!isModelAvailable(modelName)) { - logger.info("Model {} not found", modelName); + this.logger.info("Model {} not found", modelName); return; } this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)); - logger.info("Completed deletion of model: {}", modelName); + this.logger.info("Completed deletion of model: {}", modelName); } public void pullModel(String modelName) { - pullModel(modelName, options.pullModelStrategy()); + pullModel(modelName, this.options.pullModelStrategy()); } public void pullModel(String modelName, PullModelStrategy pullModelStrategy) { @@ -99,27 +101,27 @@ public void pullModel(String modelName, PullModelStrategy pullModelStrategy) { if (PullModelStrategy.WHEN_MISSING.equals(pullModelStrategy)) { if (isModelAvailable(modelName)) { - logger.debug("Model '{}' already available. Skipping pull operation.", modelName); + this.logger.debug("Model '{}' already available. Skipping pull operation.", modelName); return; } } // @formatter:off - logger.info("Start pulling model: {}", modelName); + this.logger.info("Start pulling model: {}", modelName); this.ollamaApi.pullModel(new PullModelRequest(modelName)) .bufferUntilChanged(OllamaApi.ProgressResponse::status) .doOnEach(signal -> { var progressResponses = signal.get(); if (!CollectionUtils.isEmpty(progressResponses) && progressResponses.get(progressResponses.size() - 1) != null) { - logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status()); + this.logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status()); } }) .takeUntil(progressResponses -> progressResponses.get(0) != null && progressResponses.get(0).status().equals("success")) - .timeout(options.timeout()) - .retryWhen(Retry.backoff(options.maxRetries(), Duration.ofSeconds(5))) + .timeout(this.options.timeout()) + .retryWhen(Retry.backoff(this.options.maxRetries(), Duration.ofSeconds(5))) .blockLast(); - logger.info("Completed pulling the '{}' model", modelName); + this.logger.info("Completed pulling the '{}' model", modelName); // @formatter:on } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java index 11be453aaba..e6f021008a0 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.management; /** diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java index dc7eed369f4..0a76de5a37f 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java index e1c1bfac861..3ccf39b4c86 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.metadata; import java.util.Optional; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.util.Assert; @@ -30,25 +32,25 @@ public class OllamaChatUsage implements Usage { protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }"; - public static OllamaChatUsage from(OllamaApi.ChatResponse response) { - Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - return new OllamaChatUsage(response); - } - private final OllamaApi.ChatResponse response; public OllamaChatUsage(OllamaApi.ChatResponse response) { this.response = response; } + public static OllamaChatUsage from(OllamaApi.ChatResponse response) { + Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); + return new OllamaChatUsage(response); + } + @Override public Long getPromptTokens() { - return Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L); + return Optional.ofNullable(this.response.promptEvalCount()).map(Integer::longValue).orElse(0L); } @Override public Long getGenerationTokens() { - return Optional.ofNullable(response.evalCount()).map(Integer::longValue).orElse(0L); + return Optional.ofNullable(this.response.evalCount()).map(Integer::longValue).orElse(0L); } @Override diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java index 61ea60b33c6..c75ebaac15a 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.metadata; import java.util.Optional; @@ -31,17 +32,17 @@ public class OllamaEmbeddingUsage implements Usage { protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }"; - public static OllamaEmbeddingUsage from(EmbeddingsResponse response) { - Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null"); - return new OllamaEmbeddingUsage(response); - } - private Long promptTokens; public OllamaEmbeddingUsage(EmbeddingsResponse response) { this.promptTokens = Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L); } + public static OllamaEmbeddingUsage from(EmbeddingsResponse response) { + Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null"); + return new OllamaEmbeddingUsage(response); + } + @Override public Long getPromptTokens() { return this.promptTokens; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java index a8845c0fdef..f58413f2640 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java @@ -2,12 +2,13 @@ import java.time.Duration; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.util.StringUtils; -import org.testcontainers.ollama.OllamaContainer; public class BaseOllamaIT { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index ae9d51fac8e..5738c337f8b 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -29,19 +37,12 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.tool.MockWeatherService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -74,7 +75,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -99,7 +100,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -132,4 +133,4 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 4b2fac29e55..88db3b96c3a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -33,20 +40,15 @@ import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.convert.support.DefaultConversionService; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -67,10 +69,10 @@ class OllamaChatModelIT extends BaseOllamaIT { @Test void autoPullModelTest() { - var modelManager = new OllamaModelManager(ollamaApi); + var modelManager = new OllamaModelManager(this.ollamaApi); assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue(); - String joke = ChatClient.create(chatModel) + String joke = ChatClient.create(this.chatModel) .prompt("Tell me a joke") .options(OllamaOptions.builder().withModel(ADDITIONAL_MODEL).build()) .call() @@ -97,13 +99,13 @@ void roleTest() { Prompt prompt = new Prompt(List.of(systemMessage, userMessage), portableOptions); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); // ollama specific options var ollamaOptions = new OllamaOptions().withLowVRAM(true); - response = chatModel.call(new Prompt(List.of(systemMessage, userMessage), ollamaOptions)); + response = this.chatModel.call(new Prompt(List.of(systemMessage, userMessage), ollamaOptions)); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -121,12 +123,12 @@ void testMessageHistory() { Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Hello"), response.getResult().getOutput(), new UserMessage("Tell me just the names of those pirates."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); } @@ -134,7 +136,7 @@ void testMessageHistory() { @Test void usageTest() { Prompt prompt = new Prompt("Tell me a joke"); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); Usage usage = response.getMetadata().getUsage(); assertThat(usage).isNotNull(); @@ -175,7 +177,7 @@ void mapOutputConvert() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result).isNotNull(); @@ -184,9 +186,6 @@ void mapOutputConvert() { assertThat((String) result.get("B")).containsIgnoringCase("blue"); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); @@ -198,7 +197,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -217,7 +216,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -233,6 +232,10 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { @@ -255,4 +258,4 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 5d8956552c4..0cb22784c05 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -31,9 +35,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.ClassPathResource; import org.springframework.util.MimeTypeUtils; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertThrows; @@ -57,7 +58,7 @@ void unsupportedMediaType() { var userMessage = new UserMessage("Explain what do you see in this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt(List.of(userMessage)))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt(List.of(userMessage)))); } @Test @@ -67,7 +68,7 @@ void multiModalityTest() { var userMessage = new UserMessage("Explain what do you see in this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple"); @@ -91,4 +92,4 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index 4e9f803132c..5254f3e018f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -33,10 +39,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -79,7 +81,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -103,7 +105,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -124,7 +126,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index cdb829953be..5e8e74107c7 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; @@ -39,7 +41,7 @@ public class OllamaChatRequestTests { @Test public void createRequestWithDefaultOptions() { - var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), false); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content"), false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); @@ -57,7 +59,7 @@ public void createRequestWithPromptOllamaOptions() { // Runtime options should override the default options. OllamaOptions promptOptions = new OllamaOptions().withTemperature(0.8).withTopP(0.5).withNumGPU(2); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); @@ -65,11 +67,11 @@ public void createRequestWithPromptOllamaOptions() { assertThat(request.model()).isEqualTo("MODEL_NAME"); assertThat(request.options().get("temperature")).isEqualTo(0.8); assertThat(request.options().get("top_k")).isEqualTo(99); // still the default - // value. + // value. assertThat(request.options().get("num_gpu")).isEqualTo(2); assertThat(request.options().get("top_p")).isEqualTo(0.5); // new field introduced - // by the - // promptOptions. + // by the + // promptOptions. } @Test @@ -82,7 +84,7 @@ public void createRequestWithPromptPortableChatOptions() { .withTopP(0.6) .build(); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); @@ -100,7 +102,7 @@ public void createRequestWithPromptOptionsModelOverride() { // Ollama runtime options. OllamaOptions promptOptions = new OllamaOptions().withModel("PROMPT_MODEL"); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java index ae0612ac339..00322204e9f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,25 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -52,8 +54,8 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT { @Test void embeddings() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest( + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest( List.of("Hello World", "Something else"), OllamaOptions.builder().withTruncate(false).build())); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); @@ -64,18 +66,18 @@ void embeddings() { assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @Test void autoPullModelAtStartupTime() { var model = "all-minilm"; - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - var modelManager = new OllamaModelManager(ollamaApi); + var modelManager = new OllamaModelManager(this.ollamaApi); assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue(); - EmbeddingResponse embeddingResponse = embeddingModel + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of("Hello World", "Something else"), OllamaOptions.builder().withModel(model).withTruncate(false).build())); @@ -88,7 +90,7 @@ void autoPullModelAtStartupTime() { assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); modelManager.deleteModel(ADDITIONAL_MODEL); } @@ -115,4 +117,4 @@ public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java index aaf786ff24e..ad3ebc5a72b 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; - import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -36,8 +39,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -63,13 +64,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java index b77be0a6049..0afb8c24755 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java @@ -1,26 +1,32 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.ollama; +import java.time.Duration; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -30,10 +36,6 @@ import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; import org.springframework.ai.ollama.api.OllamaOptions; -import java.time.Duration; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.when; @@ -54,7 +56,7 @@ public class OllamaEmbeddingModelTests { @Test public void options() { - when(ollamaApi.embed(embeddingsRequestCaptor.capture())) + when(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture())) .thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }), 0L, 0L, 0)) .thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2", @@ -64,7 +66,7 @@ public void options() { var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build(); var embeddingModel = OllamaEmbeddingModel.builder() - .withOllamaApi(ollamaApi) + .withOllamaApi(this.ollamaApi) .withDefaultOptions(defaultOptions) .build(); @@ -80,11 +82,11 @@ public void options() { assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME"); - assertThat(embeddingsRequestCaptor.getValue().keepAlive()).isNull(); - assertThat(embeddingsRequestCaptor.getValue().truncate()).isNull(); - assertThat(embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input1", "Input2", "Input3")); - assertThat(embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of()); - assertThat(embeddingsRequestCaptor.getValue().model()).isEqualTo("DEFAULT_MODEL"); + assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isNull(); + assertThat(this.embeddingsRequestCaptor.getValue().truncate()).isNull(); + assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input1", "Input2", "Input3")); + assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of()); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("DEFAULT_MODEL"); // Tests runtime options var runtimeOptions = OllamaOptions.builder() @@ -105,11 +107,11 @@ public void options() { assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2"); - assertThat(embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofMinutes(10)); - assertThat(embeddingsRequestCaptor.getValue().truncate()).isFalse(); - assertThat(embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input4", "Input5", "Input6")); - assertThat(embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of("main_gpu", 666)); - assertThat(embeddingsRequestCaptor.getValue().model()).isEqualTo("RUNTIME_MODEL"); + assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofMinutes(10)); + assertThat(this.embeddingsRequestCaptor.getValue().truncate()).isFalse(); + assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input4", "Input5", "Input6")); + assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of("main_gpu", 666)); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("RUNTIME_MODEL"); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java index dfa8c9f22d4..309ebc2eb67 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -38,7 +40,7 @@ public class OllamaEmbeddingRequestTests { @Test public void ollamaEmbeddingRequestDefaultOptions() { - var request = embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null); + var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null); assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); assertThat(request.options().get("num_gpu")).isEqualTo(1); @@ -56,7 +58,7 @@ public void ollamaEmbeddingRequestRequestOptions() { .withUseMMap(true)// .withNumGPU(2); - var request = embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions); + var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions); assertThat(request.model()).isEqualTo("PROMPT_MODEL"); assertThat(request.options().get("num_gpu")).isEqualTo(2); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 8a13c29b5ca..1e2bf625fc5 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import org.testcontainers.utility.DockerImageName; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java index 4b7e3f49d2b..3b030e8c6f6 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index f9ed57882f3..cbf53fb3e47 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.ollama.BaseOllamaIT; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.ChatResponse; @@ -27,12 +35,6 @@ import org.springframework.ai.ollama.api.OllamaApi.GenerateResponse; import org.springframework.ai.ollama.api.OllamaApi.Message; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; - -import java.io.IOException; -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -137,4 +139,4 @@ public void embedText() { assertThat(response.totalDuration()).isGreaterThan(1); } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java index bc4e878ab03..f565257010a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.ollama.api; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.ollama.api; import java.io.IOException; import java.time.Duration; @@ -23,9 +22,12 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ollama.BaseOllamaIT; import org.springframework.http.HttpStatus; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for the Ollama APIs to manage models. @@ -98,4 +100,4 @@ public void pullModel() { assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isTrue(); } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java index faffdf24235..205df66fe9b 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.ollama.api; -import org.junit.jupiter.api.Test; +package org.springframework.ai.ollama.api; import java.util.List; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java index 64cb56fd62a..c732a8e5ed1 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,29 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api.tool; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Christian Tzolov */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -63,28 +71,23 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java index dab9bd1799b..81d1f56b5bd 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/OllamaApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,11 +16,18 @@ package org.springframework.ai.ollama.api.tool; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.ollama.BaseOllamaIT; import org.springframework.ai.ollama.api.OllamaApi; @@ -28,13 +35,6 @@ import org.springframework.ai.ollama.api.OllamaApi.Message; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall; -import org.springframework.ai.ollama.api.OllamaModel; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -50,10 +50,10 @@ public class OllamaApiToolFunctionCallIT extends BaseOllamaIT { private static final Logger logger = LoggerFactory.getLogger(OllamaApiToolFunctionCallIT.class); - MockWeatherService weatherService = new MockWeatherService(); - static OllamaApi ollamaApi; + MockWeatherService weatherService = new MockWeatherService(); + @BeforeAll public static void beforeAll() throws IOException, InterruptedException { ollamaApi = buildOllamaApiWithModel(MODEL); @@ -117,7 +117,7 @@ public void toolFunctionCall() { MockWeatherService.Request weatherRequest = ModelOptionsUtils.mapToClass(responseMap, MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(Message.builder(Role.TOOL) @@ -140,4 +140,4 @@ public void toolFunctionCall() { assertThat(chatCompletion2.message().content()).contains("Paris").contains("15"); } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java index ab99833ca41..2d640585b10 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/management/OllamaModelManagerIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.management; +import java.io.IOException; +import java.time.Duration; +import java.util.List; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; -import org.springframework.ai.ollama.BaseOllamaIT; -import org.springframework.ai.ollama.api.OllamaModel; import org.testcontainers.junit.jupiter.Testcontainers; -import java.io.IOException; -import java.time.Duration; -import java.util.List; +import org.springframework.ai.ollama.BaseOllamaIT; +import org.springframework.ai.ollama.api.OllamaModel; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml index 1019f915a55..b059ccff986 100644 --- a/models/spring-ai-openai/pom.xml +++ b/models/spring-ai-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java index 3ec1ad510c0..5e1e1619a76 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai; public interface ImageResponseMetadata { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java index 13057cb1a32..73813b38306 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,6 +19,8 @@ import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat; @@ -33,7 +35,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import reactor.core.publisher.Flux; /** * OpenAI audio speech client implementation for backed by {@link OpenAiAudioApi}. @@ -46,6 +47,12 @@ */ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel { + /** + * The speed of the default voice synthesis. + * @see OpenAiAudioSpeechOptions + */ + private static final Float SPEED = 1.0f; + private final Logger logger = LoggerFactory.getLogger(getClass()); /** @@ -53,12 +60,6 @@ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel */ private final OpenAiAudioSpeechOptions defaultOptions; - /** - * The speed of the default voice synthesis. - * @see OpenAiAudioSpeechOptions - */ - private static final Float SPEED = 1.0f; - /** * The retry template used to retry the OpenAI Audio API calls. */ @@ -131,7 +132,7 @@ public SpeechResponse call(SpeechPrompt speechPrompt) { var speech = speechEntity.getBody(); if (speech == null) { - logger.warn("No speech response returned for speechRequest: {}", speechRequest); + this.logger.warn("No speech response returned for speechRequest: {}", speechRequest); return new SpeechResponse(new Speech(new byte[0])); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java index 8d6ca7c9de6..47cfc153cdd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.model.ModelOptions; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice; @@ -70,137 +71,150 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions(); - - public Builder withModel(String model) { - options.model = model; - return this; - } - - public Builder withInput(String input) { - options.input = input; - return this; - } - - public Builder withVoice(Voice voice) { - options.voice = voice; - return this; - } - - public Builder withResponseFormat(AudioResponseFormat responseFormat) { - options.responseFormat = responseFormat; - return this; - } - - public Builder withSpeed(Float speed) { - options.speed = speed; - return this; - } - - public OpenAiAudioSpeechOptions build() { - return options; - } - - } - public String getModel() { - return model; - } - - public String getInput() { - return input; - } - - public Voice getVoice() { - return voice; - } - - public AudioResponseFormat getResponseFormat() { - return responseFormat; - } - - public Float getSpeed() { - return speed; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((input == null) ? 0 : input.hashCode()); - result = prime * result + ((voice == null) ? 0 : voice.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((speed == null) ? 0 : speed.hashCode()); - return result; + return this.model; } public void setModel(String model) { this.model = model; } + public String getInput() { + return this.input; + } + public void setInput(String input) { this.input = input; } + public Voice getVoice() { + return this.voice; + } + public void setVoice(Voice voice) { this.voice = voice; } + public AudioResponseFormat getResponseFormat() { + return this.responseFormat; + } + public void setResponseFormat(AudioResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public Float getSpeed() { + return this.speed; + } + public void setSpeed(Float speed) { this.speed = speed; } + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.input == null) ? 0 : this.input.hashCode()); + result = prime * result + ((this.voice == null) ? 0 : this.voice.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.speed == null) ? 0 : this.speed.hashCode()); + return result; + } + @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } OpenAiAudioSpeechOptions other = (OpenAiAudioSpeechOptions) obj; - if (model == null) { - if (other.model != null) + if (this.model == null) { + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; - if (input == null) { - if (other.input != null) + } + if (this.input == null) { + if (other.input != null) { return false; + } } - else if (!input.equals(other.input)) + else if (!this.input.equals(other.input)) { return false; - if (voice == null) { - if (other.voice != null) + } + if (this.voice == null) { + if (other.voice != null) { return false; + } } - else if (!voice.equals(other.voice)) + else if (!this.voice.equals(other.voice)) { return false; - if (responseFormat == null) { - if (other.responseFormat != null) + } + if (this.responseFormat == null) { + if (other.responseFormat != null) { return false; + } } - else if (!responseFormat.equals(other.responseFormat)) + else if (!this.responseFormat.equals(other.responseFormat)) { return false; - if (speed == null) { + } + if (this.speed == null) { return other.speed == null; } - else - return speed.equals(other.speed); + else { + return this.speed.equals(other.speed); + } } @Override public String toString() { - return "OpenAiAudioSpeechOptions{" + "model='" + model + '\'' + ", input='" + input + '\'' + ", voice='" + voice - + '\'' + ", responseFormat='" + responseFormat + '\'' + ", speed=" + speed + '}'; + return "OpenAiAudioSpeechOptions{" + "model='" + this.model + '\'' + ", input='" + this.input + '\'' + + ", voice='" + this.voice + '\'' + ", responseFormat='" + this.responseFormat + '\'' + ", speed=" + + this.speed + '}'; + } + + public static class Builder { + + private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions(); + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withInput(String input) { + this.options.input = input; + return this; + } + + public Builder withVoice(Voice voice) { + this.options.voice = voice; + return this; + } + + public Builder withResponseFormat(AudioResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withSpeed(Float speed) { + this.options.speed = speed; + return this; + } + + public OpenAiAudioSpeechOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java index fbf51bb78ed..8aa43728b43 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,34 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/* -* Copyright 2024-2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ package org.springframework.ai.openai; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.audio.transcription.AudioTranscription; +import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; +import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.model.Model; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; -import org.springframework.ai.audio.transcription.AudioTranscription; -import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; -import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; @@ -133,7 +118,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPro var transcription = transcriptionEntity.getBody(); if (transcription == null) { - logger.warn("No transcription returned for request: {}", audioResource); + this.logger.warn("No transcription returned for request: {}", audioResource); return new AudioTranscriptionResponse(null); } @@ -154,7 +139,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPro var transcription = transcriptionEntity.getBody(); if (transcription == null) { - logger.warn("No transcription returned for request: {}", audioResource); + this.logger.warn("No transcription returned for request: {}", audioResource); return new AudioTranscriptionResponse(null); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java index 8808d48bada..5a2045f9bba 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.audio.transcription.AudioTranscriptionOptions; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType; @@ -58,54 +60,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected OpenAiAudioTranscriptionOptions options; - - public Builder() { - this.options = new OpenAiAudioTranscriptionOptions(); - } - - public Builder(OpenAiAudioTranscriptionOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withLanguage(String language) { - this.options.language = language; - return this; - } - - public Builder withPrompt(String prompt) { - this.options.prompt = prompt; - return this; - } - - public Builder withResponseFormat(TranscriptResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withTemperature(Float temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withGranularityType(GranularityType granularityType) { - this.options.granularityType = granularityType; - return this; - } - - public OpenAiAudioTranscriptionOptions build() { - return this.options; - } - - } - @Override public String getModel() { return this.model; @@ -139,7 +93,6 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } - public TranscriptResponseFormat getResponseFormat() { return this.responseFormat; } @@ -160,10 +113,10 @@ public void setGranularityType(GranularityType granularityType) { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((prompt == null) ? 0 : prompt.hashCode()); - result = prime * result + ((language == null) ? 0 : language.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.prompt == null) ? 0 : this.prompt.hashCode()); + result = prime * result + ((this.language == null) ? 0 : this.language.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); return result; } @@ -180,7 +133,7 @@ public boolean equals(Object obj) { if (other.model != null) return false; } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) return false; if (this.prompt == null) { if (other.prompt != null) @@ -202,4 +155,52 @@ else if (!this.responseFormat.equals(other.responseFormat)) return false; return true; } + + public static class Builder { + + protected OpenAiAudioTranscriptionOptions options; + + public Builder() { + this.options = new OpenAiAudioTranscriptionOptions(); + } + + public Builder(OpenAiAudioTranscriptionOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withLanguage(String language) { + this.options.language = language; + return this; + } + + public Builder withPrompt(String prompt) { + this.options.prompt = prompt; + return this; + } + + public Builder withResponseFormat(TranscriptResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withGranularityType(GranularityType granularityType) { + this.options.granularityType = granularityType; + return this; + } + + public OpenAiAudioTranscriptionOptions build() { + return this.options; + } + + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 8d6b60b23d0..b60f1889956 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import java.util.ArrayList; @@ -25,8 +26,14 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -72,12 +79,6 @@ import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} * backed by {@link OpenAiApi}. @@ -555,7 +556,7 @@ public ChatOptions getDefaultOptions() { @Override public String toString() { - return "OpenAiChatModel [defaultOptions=" + defaultOptions + "]"; + return "OpenAiChatModel [defaultOptions=" + this.defaultOptions + "]"; } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 3bc8c04ff9e..b79b009d9e0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import java.util.ArrayList; @@ -23,6 +24,11 @@ import java.util.Objects; import java.util.Set; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; @@ -35,11 +41,6 @@ import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * @author Christian Tzolov * @author Mariusz Bernacki @@ -202,159 +203,33 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected OpenAiChatOptions options; - - public Builder() { - this.options = new OpenAiChatOptions(); - } - - public Builder(OpenAiChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withModel(OpenAiApi.ChatModel openAiChatModel) { - this.options.model = openAiChatModel.getName(); - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withLogitBias(Map logitBias) { - this.options.logitBias = logitBias; - return this; - } - - public Builder withLogprobs(Boolean logprobs) { - this.options.logprobs = logprobs; - return this; - } - - public Builder withTopLogprobs(Integer topLogprobs) { - this.options.topLogprobs = topLogprobs; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withMaxCompletionTokens(Integer maxCompletionTokens) { - this.options.maxCompletionTokens = maxCompletionTokens; - return this; - } - - public Builder withN(Integer n) { - this.options.n = n; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withResponseFormat(ResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withStreamUsage(boolean enableStreamUsage) { - this.options.streamOptions = (enableStreamUsage) ? StreamOptions.INCLUDE_USAGE : null; - return this; - } - - public Builder withSeed(Integer seed) { - this.options.seed = seed; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTools(List tools) { - this.options.tools = tools; - return this; - } - - public Builder withToolChoice(String toolChoice) { - this.options.toolChoice = toolChoice; - return this; - } - - public Builder withUser(String user) { - this.options.user = user; - return this; - } - - public Builder withParallelToolCalls(Boolean parallelToolCalls) { - this.options.parallelToolCalls = parallelToolCalls; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withHttpHeaders(Map httpHeaders) { - this.options.httpHeaders = httpHeaders; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public OpenAiChatOptions build() { - return this.options; - } - + public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { + return OpenAiChatOptions.builder() + .withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withLogitBias(fromOptions.getLogitBias()) + .withLogprobs(fromOptions.getLogprobs()) + .withTopLogprobs(fromOptions.getTopLogprobs()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withMaxCompletionTokens(fromOptions.getMaxCompletionTokens()) + .withN(fromOptions.getN()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStreamUsage(fromOptions.getStreamUsage()) + .withSeed(fromOptions.getSeed()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withUser(fromOptions.getUser()) + .withParallelToolCalls(fromOptions.getParallelToolCalls()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withHttpHeaders(fromOptions.getHttpHeaders()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); } public Boolean getStreamUsage() { @@ -417,7 +292,7 @@ public void setMaxTokens(Integer maxTokens) { } public Integer getMaxCompletionTokens() { - return maxCompletionTokens; + return this.maxCompletionTokens; } public void setMaxCompletionTokens(Integer maxCompletionTokens) { @@ -450,7 +325,7 @@ public void setResponseFormat(ResponseFormat responseFormat) { } public StreamOptions getStreamOptions() { - return streamOptions; + return this.streamOptions; } public void setStreamOptions(StreamOptions streamOptions) { @@ -514,6 +389,10 @@ public String getToolChoice() { return this.toolChoice; } + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + @Override public Boolean getProxyToolCalls() { return this.proxyToolCalls; @@ -523,10 +402,6 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } - public void setToolChoice(String toolChoice) { - this.toolChoice = toolChoice; - } - public String getUser() { return this.user; } @@ -555,7 +430,7 @@ public void setFunctionCallbacks(List functionCallbacks) { @Override public Set getFunctions() { - return functions; + return this.functions; } public void setFunctions(Set functionNames) { @@ -591,35 +466,6 @@ public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); } - public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { - return OpenAiChatOptions.builder() - .withModel(fromOptions.getModel()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withLogitBias(fromOptions.getLogitBias()) - .withLogprobs(fromOptions.getLogprobs()) - .withTopLogprobs(fromOptions.getTopLogprobs()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withMaxCompletionTokens(fromOptions.getMaxCompletionTokens()) - .withN(fromOptions.getN()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withStreamUsage(fromOptions.getStreamUsage()) - .withSeed(fromOptions.getSeed()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withUser(fromOptions.getUser()) - .withParallelToolCalls(fromOptions.getParallelToolCalls()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withHttpHeaders(fromOptions.getHttpHeaders()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); - } - @Override public int hashCode() { return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, @@ -631,10 +477,12 @@ public int hashCode() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } OpenAiChatOptions other = (OpenAiChatOptions) o; return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) && Objects.equals(this.logitBias, other.logitBias) && Objects.equals(this.logprobs, other.logprobs) @@ -660,4 +508,159 @@ public String toString() { return "OpenAiChatOptions: " + ModelOptionsUtils.toJsonString(this); } + public static class Builder { + + protected OpenAiChatOptions options; + + public Builder() { + this.options = new OpenAiChatOptions(); + } + + public Builder(OpenAiChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withModel(OpenAiApi.ChatModel openAiChatModel) { + this.options.model = openAiChatModel.getName(); + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withLogitBias(Map logitBias) { + this.options.logitBias = logitBias; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + this.options.topLogprobs = topLogprobs; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withMaxCompletionTokens(Integer maxCompletionTokens) { + this.options.maxCompletionTokens = maxCompletionTokens; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStreamUsage(boolean enableStreamUsage) { + this.options.streamOptions = (enableStreamUsage) ? StreamOptions.INCLUDE_USAGE : null; + return this; + } + + public Builder withSeed(Integer seed) { + this.options.seed = seed; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withParallelToolCalls(Boolean parallelToolCalls) { + this.options.parallelToolCalls = parallelToolCalls; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withHttpHeaders(Map httpHeaders) { + this.options.httpHeaders = httpHeaders; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public OpenAiChatOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 89f61224ad5..a0824cf5674 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -27,9 +31,9 @@ import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; -import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; @@ -40,8 +44,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * Open AI Embedding Model implementation. * diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java index dbd0a2979d0..64f173eaf05 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import com.fasterxml.jackson.annotation.JsonInclude; @@ -51,40 +52,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected OpenAiEmbeddingOptions options; - - public Builder() { - this.options = new OpenAiEmbeddingOptions(); - } - - public Builder withModel(String model) { - this.options.setModel(model); - return this; - } - - public Builder withEncodingFormat(String encodingFormat) { - this.options.setEncodingFormat(encodingFormat); - return this; - } - - public Builder withDimensions(Integer dimensions) { - this.options.dimensions = dimensions; - return this; - } - - public Builder withUser(String user) { - this.options.setUser(user); - return this; - } - - public OpenAiEmbeddingOptions build() { - return this.options; - } - - } - @Override public String getModel() { return this.model; @@ -119,4 +86,38 @@ public void setUser(String user) { this.user = user; } + public static class Builder { + + protected OpenAiEmbeddingOptions options; + + public Builder() { + this.options = new OpenAiEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withEncodingFormat(String encodingFormat) { + this.options.setEncodingFormat(encodingFormat); + return this; + } + + public Builder withDimensions(Integer dimensions) { + this.options.dimensions = dimensions; + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public OpenAiEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index 3fabb9f4976..1da5a9b8694 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageModel; @@ -26,8 +30,8 @@ import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponseMetadata; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; -import org.springframework.ai.image.observation.ImageModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationContext; +import org.springframework.ai.image.observation.ImageModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationDocumentation; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -39,8 +43,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * OpenAiImageModel is a class that implements the ImageModel interface. It provides a * client for calling the OpenAI image generation API. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java index 949614dae5c..1bab88096dd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; -import java.util.Objects; - /** * OpenAI Image API options. OpenAiImageOptions.java * @@ -99,60 +100,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final OpenAiImageOptions options; - - private Builder() { - this.options = new OpenAiImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withQuality(String quality) { - options.setQuality(quality); - return this; - } - - public Builder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withStyle(String style) { - options.setStyle(style); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public OpenAiImageOptions build() { - return options; - } - - } - @Override public Integer getN() { return this.n; @@ -181,7 +128,7 @@ public void setQuality(String quality) { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -247,10 +194,6 @@ public void setUser(String user) { this.user = user; } - public void setSize(String size) { - this.size = size; - } - public String getSize() { if (this.size != null) { return this.size; @@ -258,28 +201,91 @@ public String getSize() { return (this.width != null && this.height != null) ? this.width + "x" + this.height : null; } + public void setSize(String size) { + this.size = size; + } + @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof OpenAiImageOptions that)) + } + if (!(o instanceof OpenAiImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(quality, that.quality) - && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(size, that.size) - && Objects.equals(style, that.style) && Objects.equals(user, that.user); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.width, that.width) && Objects.equals(this.height, that.height) + && Objects.equals(this.quality, that.quality) + && Objects.equals(this.responseFormat, that.responseFormat) && Objects.equals(this.size, that.size) + && Objects.equals(this.style, that.style) && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(n, model, width, height, quality, responseFormat, size, style, user); + return Objects.hash(this.n, this.model, this.width, this.height, this.quality, this.responseFormat, this.size, + this.style, this.user); } @Override public String toString() { - return "OpenAiImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", width=" + width + ", height=" + height - + ", quality='" + quality + '\'' + ", responseFormat='" + responseFormat + '\'' + ", size='" + size - + '\'' + ", style='" + style + '\'' + ", user='" + user + '\'' + '}'; + return "OpenAiImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", width=" + this.width + + ", height=" + this.height + ", quality='" + this.quality + '\'' + ", responseFormat='" + + this.responseFormat + '\'' + ", size='" + this.size + '\'' + ", style='" + this.style + '\'' + + ", user='" + this.user + '\'' + '}'; + } + + public static class Builder { + + private final OpenAiImageOptions options; + + private Builder() { + this.options = new OpenAiImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withQuality(String quality) { + this.options.setQuality(quality); + return this; + } + + public Builder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public OpenAiImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java index dbf662af08d..f6719710c05 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,28 @@ package org.springframework.ai.openai; +import java.util.ArrayList; +import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.moderation.*; +import org.springframework.ai.moderation.Categories; +import org.springframework.ai.moderation.CategoryScores; +import org.springframework.ai.moderation.Generation; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationModel; +import org.springframework.ai.moderation.ModerationOptions; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResponse; +import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; - /** * OpenAiModerationModel is a class that implements the ModerationModel interface. It * provides a client for calling the OpenAI moderation generation API. @@ -40,12 +49,12 @@ public class OpenAiModerationModel implements ModerationModel { private final Logger logger = LoggerFactory.getLogger(getClass()); - private OpenAiModerationOptions defaultOptions; - private final OpenAiModerationApi openAiModerationApi; private final RetryTemplate retryTemplate; + private OpenAiModerationOptions defaultOptions; + public OpenAiModerationModel(OpenAiModerationApi openAiModerationApi) { this(openAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE); } @@ -97,7 +106,7 @@ private ModerationResponse convertResponse( OpenAiModerationApi.OpenAiModerationRequest openAiModerationRequest) { OpenAiModerationApi.OpenAiModerationResponse moderationApiResponse = moderationResponseEntity.getBody(); if (moderationApiResponse == null) { - logger.warn("No moderation response returned for request: {}", openAiModerationRequest); + this.logger.warn("No moderation response returned for request: {}", openAiModerationRequest); return new ModerationResponse(new Generation()); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java index 9abacec51d1..49688231357 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.moderation.ModerationOptions; import org.springframework.ai.openai.api.OpenAiModerationApi; @@ -39,6 +41,15 @@ public static Builder builder() { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + public static class Builder { private final OpenAiModerationOptions options; @@ -48,23 +59,14 @@ private Builder() { } public Builder withModel(String model) { - options.setModel(model); + this.options.setModel(model); return this; } public OpenAiModerationOptions build() { - return options; + return this.options; } } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java index b355e1b2085..3a4fe1fa5e2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.aot; +import java.util.Set; + import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -25,8 +28,6 @@ import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; -import java.util.Set; - import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 657f37a1366..b4dd95d29fe 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; @@ -21,6 +22,12 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.common.OpenAiApiConstants; @@ -39,13 +46,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Single class implementation of the * OpenAI Chat Completion @@ -74,6 +74,8 @@ public class OpenAiApi { private final WebClient webClient; + private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper(); + /** * Create a new chat completion api with base URL set to https://api.openai.com * @param apiKey OpenAI apiKey. @@ -173,6 +175,155 @@ public OpenAiApi(String baseUrl, String apiKey, MultiValueMap he .build();// @formatter:on } + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @param additionalHttpHeader Optional, additional HTTP headers to be added to the + * request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); + + return this.restClient.post() + .uri(this.completionsPath) + .headers(headers -> headers.addAll(additionalHttpHeader)) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @param additionalHttpHeader Optional, additional HTTP headers to be added to the + * request. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri(this.completionsPath) + .headers(headers -> headers.addAll(additionalHttpHeader)) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + // Detect is the chunk is part of a streaming function call. + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + // Group all chunks belonging to the same function call. + // Flux -> Flux> + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + // Reduce the inner Flux window into a single + // Mono, + // Flux> -> Flux> + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + // Flux> -> Flux + .flatMap(mono -> mono); + } + + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or + * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single + * request, You can pass a {@link List} of {@link String} or {@link List} of + * {@link List} of tokens. For example: + * + *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + // The input must not exceed the max input tokens for the model (8192 tokens for + // text-embedding-ada-002), cannot + // be an empty string, and any array must be 2048 dimensions or less. + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 2048, "The list must be 2048 dimensions or less"); + Assert.isTrue( + list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri(this.embeddingsPath) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + /** * OpenAI Chat Completion Models: * @@ -296,7 +447,7 @@ public enum ChatModel implements ChatModelDescription { } public String getValue() { - return value; + return this.value; } @Override @@ -306,6 +457,79 @@ public String getName() { } + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") + STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") + LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") + CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") + TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") + FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") + TOOL_CALL + + } + + /** + * OpenAI Embeddings Models: + *
Embeddings. + */ + public enum EmbeddingModel { + + /** + * Most capable embedding model for both english and non-english tasks. DIMENSION: + * 3072 + */ + TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"), + + /** + * Increased performance over 2nd generation ada embedding model. DIMENSION: 1536 + */ + TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"), + + /** + * Most capable 2nd generation embedding model, replacing 16 first generation + * models. DIMENSION: 1536 + */ + TEXT_EMBEDDING_ADA_002("text-embedding-ada-002"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + /** * Represents a tool the model may call. Currently, only functions are supported as a * tool. @@ -523,9 +747,9 @@ public ChatCompletionRequest(List messages, Boolean strea * @return A new {@link ChatCompletionRequest} with the specified stream options. */ public ChatCompletionRequest withStreamOptions(StreamOptions streamOptions) { - return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, maxCompletionTokens, n, presencePenalty, - responseFormat, seed, stop, stream, streamOptions, temperature, topP, - tools, toolChoice, parallelToolCalls, user); + return new ChatCompletionRequest(this.messages, this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, + this.responseFormat, this.seed, this.stop, this.stream, streamOptions, this.temperature, this.topP, + this.tools, this.toolChoice, this.parallelToolCalls, this.user); } /** @@ -559,7 +783,20 @@ public static Object FUNCTION(String functionName) { public record ResponseFormat( @JsonProperty("type") Type type, @JsonProperty("json_schema") JsonSchema jsonSchema ) { - + + public ResponseFormat(Type type) { + this(type, (JsonSchema) null); + } + + public ResponseFormat(Type type, String schema) { + this(type, "custom_schema", schema, true); + } + + @ConstructorBinding + public ResponseFormat(Type type, String name, String schema, Boolean strict) { + this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null); + } + public enum Type { /** * Generates a text response. (default) @@ -604,19 +841,6 @@ public JsonSchema(String name, String schema, Boolean strict) { } } - public ResponseFormat(Type type) { - this(type, (JsonSchema) null); - } - - public ResponseFormat(Type type, String schema) { - this(type, "custom_schema", schema, true); - } - - @ConstructorBinding - public ResponseFormat(Type type, String name, String schema, Boolean strict) { - this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null); - } - } /** @@ -658,6 +882,16 @@ public record ChatCompletionMessage(// @formatter:off @JsonProperty("tool_calls") List toolCalls, @JsonProperty("refusal") String refusal) {// @formatter:on + /** + * Create a chat completion message with the given content and role. All other + * fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null, null); + } + /** * Get message content as String. */ @@ -671,16 +905,6 @@ public String content() { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other - * fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null, null); - } - /** * The role of the author of this message. */ @@ -725,19 +949,6 @@ public record MediaContent(// @formatter:off @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl) { // @formatter:on - /** - * @param url Either a URL of the image or the base64 encoded image data. The - * base64 encoded image data must have a special prefix in the following - * format: "data:{mimetype};base64,{base64-encoded-image-data}". - * @param detail Specifies the detail level of the image. - */ - @JsonInclude(Include.NON_NULL) - public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") String detail) { - - public ImageUrl(String url) { - this(url, null); - } - } /** * Shortcut constructor for a text content. @@ -754,6 +965,22 @@ public MediaContent(String text) { public MediaContent(ImageUrl imageUrl) { this("image_url", null, imageUrl); } + + /** + * @param url Either a URL of the image or the base64 encoded image data. The + * base64 encoded image data must have a special prefix in the following + * format: "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + + } + } /** @@ -777,6 +1004,7 @@ public record ToolCall(// @formatter:off public ToolCall(String id, String type, ChatCompletionFunction function) { this(null, id, type, function); } + } /** @@ -791,50 +1019,6 @@ public record ChatCompletionFunction(// @formatter:off @JsonProperty("name") String name, @JsonProperty("arguments") String arguments) {// @formatter:on } - } - - public static String getTextContent(List content) { - return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") - STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") - LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") - CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") - TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") - FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") - TOOL_CALL } @@ -880,6 +1064,7 @@ public record Choice(// @formatter:off @JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on } + } /** @@ -928,9 +1113,13 @@ public record TopLogProbs(// @formatter:off @JsonProperty("logprob") Float logprob, @JsonProperty("bytes") List probBytes) {// @formatter:on } + } + } + // Embeddings API + /** * Usage statistics for the completion request. * @@ -1018,144 +1207,6 @@ public record ChunkChoice(// @formatter:off @JsonProperty("delta") ChatCompletionMessage delta, @JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on } - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @param additionalHttpHeader Optional, additional HTTP headers to be added to the - * request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); - - return this.restClient.post() - .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); - } - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @param additionalHttpHeader Optional, additional HTTP headers to be added to the - * request. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - // cancels the flux stream after the "[DONE]" is received. - .takeUntil(SSE_DONE_PREDICATE) - // filters out the "[DONE]" message. - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - // Detect is the chunk is part of a streaming function call. - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - // Group all chunks belonging to the same function call. - // Flux -> Flux> - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - // Reduce the inner Flux window into a single - // Mono, - // Flux> -> Flux> - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(monoChunk); - }) - // Flux> -> Flux - .flatMap(mono -> mono); - } - - // Embeddings API - - /** - * OpenAI Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * Most capable embedding model for both english and non-english tasks. DIMENSION: - * 3072 - */ - TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"), - - /** - * Increased performance over 2nd generation ada embedding model. DIMENSION: 1536 - */ - TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"), - - /** - * Most capable 2nd generation embedding model, replacing 16 first generation - * models. DIMENSION: 1536 - */ - TEXT_EMBEDDING_ADA_002("text-embedding-ada-002"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } } @@ -1183,6 +1234,7 @@ public record Embedding(// @formatter:off public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } + } /** @@ -1227,6 +1279,7 @@ public EmbeddingRequest(T input, String model) { public EmbeddingRequest(T input) { this(input, DEFAULT_EMBEDDING_MODEL); } + } /** @@ -1246,45 +1299,4 @@ public record EmbeddingList(// @formatter:off @JsonProperty("usage") Usage usage) {// @formatter:on } - /** - * Creates an embedding vector representing the input text or token array. - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or - * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single - * request, You can pass a {@link List} of {@link String} or {@link List} of - * {@link List} of tokens. For example: - * - *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
- */ - public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple - // inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.input(), "The input can not be null."); - Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, - "The input must be either a String, or a List of Strings or List of List of integers."); - - // The input must not exceed the max input tokens for the model (8192 tokens for - // text-embedding-ada-002), cannot - // be an empty string, and any array must be 2048 dimensions or less. - if (embeddingRequest.input() instanceof List list) { - Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); - Assert.isTrue(list.size() <= 2048, "The list must be 2048 dimensions or less"); - Assert.isTrue( - list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, - "The input must be either a String, or a List of Strings or list of list of integers."); - } - - return this.restClient.post() - .uri(this.embeddingsPath) - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 72161328ffe..a217026093f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; import java.util.Map; import java.util.function.Consumer; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.io.ByteArrayResource; @@ -33,13 +40,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Turn audio into text or text into audio. Based on * OpenAI Audio @@ -124,6 +124,125 @@ public OpenAiAudioApi(String baseUrl, String apiKey, MultiValueMap createSpeech(SpeechRequest requestBody) { + return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class); + } + + /** + * Streams audio generated from the input text. + * + * This method sends a POST request to the OpenAI API to generate audio from the + * provided text. The audio is streamed back as a Flux of ResponseEntity objects, each + * containing a byte array of the audio data. + * @param requestBody The request body containing the details for the audio + * generation, such as the input text, model, voice, and response format. + * @return A Flux of ResponseEntity objects, each containing a byte array of the audio + * data. + */ + public Flux> stream(SpeechRequest requestBody) { + + return this.webClient.post() + .uri("/v1/audio/speech") + .body(Mono.just(requestBody), SpeechRequest.class) + .accept(MediaType.APPLICATION_OCTET_STREAM) + .exchangeToFlux(clientResponse -> { + HttpHeaders headers = clientResponse.headers().asHttpHeaders(); + return clientResponse.bodyToFlux(byte[].class) + .map(bytes -> ResponseEntity.ok().headers(headers).body(bytes)); + }); + } + + /** + * Transcribes audio into the input language. + * @param requestBody The request body. + * @return Response entity containing the transcribed text in either json or text + * format. + */ + public ResponseEntity createTranscription(TranscriptionRequest requestBody) { + return createTranscription(requestBody, requestBody.responseFormat().getResponseType()); + } + + /** + * Transcribes audio into the input language. The response type is specified by the + * responseType parameter. + * @param The response type. + * @param requestBody The request body. + * @param responseType The response type class. + * @return Response entity containing the transcribed text in the responseType format. + */ + public ResponseEntity createTranscription(TranscriptionRequest requestBody, Class responseType) { + + MultiValueMap multipartBody = new LinkedMultiValueMap<>(); + multipartBody.add("file", new ByteArrayResource(requestBody.file()) { + + @Override + public String getFilename() { + return "audio.webm"; + } + }); + multipartBody.add("model", requestBody.model()); + multipartBody.add("language", requestBody.language()); + multipartBody.add("prompt", requestBody.prompt()); + multipartBody.add("response_format", requestBody.responseFormat().getValue()); + multipartBody.add("temperature", requestBody.temperature()); + if (requestBody.granularityType() != null) { + Assert.isTrue(requestBody.responseFormat() == TranscriptResponseFormat.VERBOSE_JSON, + "response_format must be set to verbose_json to use timestamp granularities."); + multipartBody.add("timestamp_granularities[]", requestBody.granularityType().getValue()); + } + + return this.restClient.post() + .uri("/v1/audio/transcriptions") + .body(multipartBody) + .retrieve() + .toEntity(responseType); + } + + /** + * Translates audio into English. + * @param requestBody The request body. + * @return Response entity containing the transcribed text in either json or text + * format. + */ + public ResponseEntity createTranslation(TranslationRequest requestBody) { + return createTranslation(requestBody, requestBody.responseFormat().getResponseType()); + } + + /** + * Translates audio into English. The response type is specified by the responseType + * parameter. + * @param The response type. + * @param requestBody The request body. + * @param responseType The response type class. + * @return Response entity containing the transcribed text in the responseType format. + */ + public ResponseEntity createTranslation(TranslationRequest requestBody, Class responseType) { + + MultiValueMap multipartBody = new LinkedMultiValueMap<>(); + multipartBody.add("file", new ByteArrayResource(requestBody.file()) { + + @Override + public String getFilename() { + return "audio.webm"; + } + }); + multipartBody.add("model", requestBody.model()); + multipartBody.add("prompt", requestBody.prompt()); + multipartBody.add("response_format", requestBody.responseFormat().getValue()); + multipartBody.add("temperature", requestBody.temperature()); + + return this.restClient.post() + .uri("/v1/audio/translations") + .body(multipartBody) + .retrieve() + .toEntity(responseType); + } + /** * TTS is an AI model that converts text to natural sounding spoken text. We offer two * different model variates, tts-1 is optimized for real time text to speech use cases @@ -156,6 +275,69 @@ public String getValue() { } + /** + * Whisper is a + * general-purpose speech recognition model. It is trained on a large dataset of + * diverse audio and is also a multi-task model that can perform multilingual speech + * recognition as well as speech translation and language identification. The Whisper + * v2-large model is currently available through our API with the whisper-1 model + * name. + */ + public enum WhisperModel { + + // @formatter:off + @JsonProperty("whisper-1") WHISPER_1("whisper-1"); + // @formatter:on + + public final String value; + + WhisperModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * The format of the transcript and translation outputs, in one of these options: + * json, text, srt, verbose_json, or vtt. Defaults to json. + */ + public enum TranscriptResponseFormat { + + // @formatter:off + @JsonProperty("json") JSON("json", StructuredResponse.class), + @JsonProperty("text") TEXT("text", String.class), + @JsonProperty("srt") SRT("srt", String.class), + @JsonProperty("verbose_json") VERBOSE_JSON("verbose_json", StructuredResponse.class), + @JsonProperty("vtt") VTT("vtt", String.class); + // @formatter:on + + public final String value; + + public final Class responseType; + + TranscriptResponseFormat(String value, Class responseType) { + this.value = value; + this.responseType = responseType; + } + + public boolean isJsonType() { + return this == JSON || this == VERBOSE_JSON; + } + + public String getValue() { + return this.value; + } + + public Class getResponseType() { + return this.responseType; + } + + } + /** * Request to generates audio from the input text. Reference: * Create @@ -181,12 +363,16 @@ public record SpeechRequest( @JsonProperty("speed") Float speed) { // @formatter:on + public static Builder builder() { + return new Builder(); + } + /** * The voice to use for synthesis. */ public enum Voice { - // @formatter:off + // @formatter:off @JsonProperty("alloy") ALLOY("alloy"), @JsonProperty("echo") ECHO("echo"), @JsonProperty("fable") FABLE("fable"), @@ -232,10 +418,6 @@ public String getValue() { } - public static Builder builder() { - return new Builder(); - } - /** * Builder for the SpeechRequest. */ @@ -277,38 +459,13 @@ public Builder withSpeed(Float speed) { } public SpeechRequest build() { - Assert.hasText(model, "model must not be empty"); - Assert.hasText(input, "input must not be empty"); + Assert.hasText(this.model, "model must not be empty"); + Assert.hasText(this.input, "input must not be empty"); return new SpeechRequest(this.model, this.input, this.voice, this.responseFormat, this.speed); } } - } - - /** - * Whisper is a - * general-purpose speech recognition model. It is trained on a large dataset of - * diverse audio and is also a multi-task model that can perform multilingual speech - * recognition as well as speech translation and language identification. The Whisper - * v2-large model is currently available through our API with the whisper-1 model - * name. - */ - public enum WhisperModel { - - // @formatter:off - @JsonProperty("whisper-1") WHISPER_1("whisper-1"); - // @formatter:on - - public final String value; - - WhisperModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } } @@ -347,6 +504,10 @@ public record TranscriptionRequest( @JsonProperty("timestamp_granularities") GranularityType granularityType) { // @formatter:on + public static Builder builder() { + return new Builder(); + } + public enum GranularityType { // @formatter:off @@ -366,10 +527,6 @@ public String getValue() { } - public static Builder builder() { - return new Builder(); - } - public static class Builder { private byte[] file; @@ -431,42 +588,6 @@ public TranscriptionRequest build() { } } - } - - /** - * The format of the transcript and translation outputs, in one of these options: - * json, text, srt, verbose_json, or vtt. Defaults to json. - */ - public enum TranscriptResponseFormat { - - // @formatter:off - @JsonProperty("json") JSON("json", StructuredResponse.class), - @JsonProperty("text") TEXT("text", String.class), - @JsonProperty("srt") SRT("srt", String.class), - @JsonProperty("verbose_json") VERBOSE_JSON("verbose_json", StructuredResponse.class), - @JsonProperty("vtt") VTT("vtt", String.class); - // @formatter:on - - public final String value; - - public final Class responseType; - - public boolean isJsonType() { - return this == JSON || this == VERBOSE_JSON; - } - - TranscriptResponseFormat(String value, Class responseType) { - this.value = value; - this.responseType = responseType; - } - - public String getValue() { - return this.value; - } - - public Class getResponseType() { - return this.responseType; - } } @@ -537,15 +658,16 @@ public Builder withTemperature(Float temperature) { } public TranslationRequest build() { - Assert.notNull(file, "file must not be null"); - Assert.hasText(model, "model must not be empty"); - Assert.notNull(responseFormat, "response_format must not be null"); + Assert.notNull(this.file, "file must not be null"); + Assert.hasText(this.model, "model must not be empty"); + Assert.notNull(this.responseFormat, "response_format must not be null"); return new TranslationRequest(this.file, this.model, this.prompt, this.responseFormat, this.temperature); } } + } /** @@ -619,123 +741,7 @@ public record Segment( @JsonProperty("no_speech_prob") Float noSpeechProb) { // @formatter:on } - } - - /** - * Request to generates audio from the input text. - * @param requestBody The request body. - * @return Response entity containing the audio binary. - */ - public ResponseEntity createSpeech(SpeechRequest requestBody) { - return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class); - } - - /** - * Streams audio generated from the input text. - * - * This method sends a POST request to the OpenAI API to generate audio from the - * provided text. The audio is streamed back as a Flux of ResponseEntity objects, each - * containing a byte array of the audio data. - * @param requestBody The request body containing the details for the audio - * generation, such as the input text, model, voice, and response format. - * @return A Flux of ResponseEntity objects, each containing a byte array of the audio - * data. - */ - public Flux> stream(SpeechRequest requestBody) { - - return webClient.post() - .uri("/v1/audio/speech") - .body(Mono.just(requestBody), SpeechRequest.class) - .accept(MediaType.APPLICATION_OCTET_STREAM) - .exchangeToFlux(clientResponse -> { - HttpHeaders headers = clientResponse.headers().asHttpHeaders(); - return clientResponse.bodyToFlux(byte[].class) - .map(bytes -> ResponseEntity.ok().headers(headers).body(bytes)); - }); - } - - /** - * Transcribes audio into the input language. - * @param requestBody The request body. - * @return Response entity containing the transcribed text in either json or text - * format. - */ - public ResponseEntity createTranscription(TranscriptionRequest requestBody) { - return createTranscription(requestBody, requestBody.responseFormat().getResponseType()); - } - - /** - * Transcribes audio into the input language. The response type is specified by the - * responseType parameter. - * @param The response type. - * @param requestBody The request body. - * @param responseType The response type class. - * @return Response entity containing the transcribed text in the responseType format. - */ - public ResponseEntity createTranscription(TranscriptionRequest requestBody, Class responseType) { - - MultiValueMap multipartBody = new LinkedMultiValueMap<>(); - multipartBody.add("file", new ByteArrayResource(requestBody.file()) { - @Override - public String getFilename() { - return "audio.webm"; - } - }); - multipartBody.add("model", requestBody.model()); - multipartBody.add("language", requestBody.language()); - multipartBody.add("prompt", requestBody.prompt()); - multipartBody.add("response_format", requestBody.responseFormat().getValue()); - multipartBody.add("temperature", requestBody.temperature()); - if (requestBody.granularityType() != null) { - Assert.isTrue(requestBody.responseFormat() == TranscriptResponseFormat.VERBOSE_JSON, - "response_format must be set to verbose_json to use timestamp granularities."); - multipartBody.add("timestamp_granularities[]", requestBody.granularityType().getValue()); - } - return this.restClient.post() - .uri("/v1/audio/transcriptions") - .body(multipartBody) - .retrieve() - .toEntity(responseType); - } - - /** - * Translates audio into English. - * @param requestBody The request body. - * @return Response entity containing the transcribed text in either json or text - * format. - */ - public ResponseEntity createTranslation(TranslationRequest requestBody) { - return createTranslation(requestBody, requestBody.responseFormat().getResponseType()); - } - - /** - * Translates audio into English. The response type is specified by the responseType - * parameter. - * @param The response type. - * @param requestBody The request body. - * @param responseType The response type class. - * @return Response entity containing the transcribed text in the responseType format. - */ - public ResponseEntity createTranslation(TranslationRequest requestBody, Class responseType) { - - MultiValueMap multipartBody = new LinkedMultiValueMap<>(); - multipartBody.add("file", new ByteArrayResource(requestBody.file()) { - @Override - public String getFilename() { - return "audio.webm"; - } - }); - multipartBody.add("model", requestBody.model()); - multipartBody.add("prompt", requestBody.prompt()); - multipartBody.add("response_format", requestBody.responseFormat().getValue()); - multipartBody.add("temperature", requestBody.temperature()); - - return this.restClient.post() - .uri("/v1/audio/translations") - .body(multipartBody) - .retrieve() - .toEntity(responseType); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index 698bbbae5dd..c534054079e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.MediaType; @@ -28,9 +32,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * OpenAI Image API. * @@ -95,6 +96,17 @@ public OpenAiImageApi(String baseUrl, String apiKey, MultiValueMap createImage(OpenAiImageRequest openAiImageRequest) { + Assert.notNull(openAiImageRequest, "Image request cannot be null."); + Assert.hasLength(openAiImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("v1/images/generations") + .body(openAiImageRequest) + .retrieve() + .toEntity(OpenAiImageResponse.class); + } + /** * OpenAI Image API model. * DALL·E @@ -147,24 +159,12 @@ public record OpenAiImageResponse( @JsonProperty("created") Long created, @JsonProperty("data") List data) { } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("url") String url, - @JsonProperty("b64_json") String b64Json, - @JsonProperty("revised_prompt") String revisedPrompt) { - } // @formatter:onn - public ResponseEntity createImage(OpenAiImageRequest openAiImageRequest) { - Assert.notNull(openAiImageRequest, "Image request cannot be null."); - Assert.hasLength(openAiImageRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url, @JsonProperty("b64_json") String b64Json, + @JsonProperty("revised_prompt") String revisedPrompt) { - return this.restClient.post() - .uri("v1/images/generations") - .body(openAiImageRequest) - .retrieve() - .toEntity(OpenAiImageResponse.class); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index 092cc62bc96..02e2b3ca109 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.springframework.ai.retry.RetryUtils; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; @@ -22,11 +28,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; - /** * OpenAI Moderation API. * @@ -36,10 +37,10 @@ */ public class OpenAiModerationApi { - private static final String DEFAULT_BASE_URL = "https://api.openai.com"; - public static final String DEFAULT_MODERATION_MODEL = "text-moderation-latest"; + private static final String DEFAULT_BASE_URL = "https://api.openai.com"; + private final RestClient restClient; private final ObjectMapper objectMapper; @@ -63,6 +64,17 @@ public OpenAiModerationApi(String baseUrl, String openAiToken, RestClient.Builde }).defaultStatusHandler(responseErrorHandler).build(); } + public ResponseEntity createModeration(OpenAiModerationRequest openAiModerationRequest) { + Assert.notNull(openAiModerationRequest, "Moderation request cannot be null."); + Assert.hasLength(openAiModerationRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("v1/moderations") + .body(openAiModerationRequest) + .retrieve() + .toEntity(OpenAiModerationResponse.class); + } + // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) public record OpenAiModerationRequest ( @@ -82,6 +94,7 @@ public record OpenAiModerationResponse( @JsonProperty("results") OpenAiModerationResult[] results) { } + @JsonInclude(JsonInclude.Include.NON_NULL) public record OpenAiModerationResult( @JsonProperty("flagged") boolean flagged, @@ -89,6 +102,7 @@ public record OpenAiModerationResult( @JsonProperty("category_scores") CategoryScores categoryScores) { } + @JsonInclude(JsonInclude.Include.NON_NULL) public record Categories( @JsonProperty("sexual") boolean sexual, @@ -119,25 +133,12 @@ public record CategoryScores( @JsonProperty("violence") double violence) { } - - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("url") String url, - @JsonProperty("b64_json") String b64Json, - @JsonProperty("revised_prompt") String revisedPrompt) { - } // @formatter:onn - public ResponseEntity createModeration(OpenAiModerationRequest openAiModerationRequest) { - Assert.notNull(openAiModerationRequest, "Moderation request cannot be null."); - Assert.hasLength(openAiModerationRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url, @JsonProperty("b64_json") String b64Json, + @JsonProperty("revised_prompt") String revisedPrompt) { - return this.restClient.post() - .uri("v1/moderations") - .body(openAiModerationRequest) - .retrieve() - .toEntity(OpenAiModerationResponse.class); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index 02bfd310800..9bdd2ea1859 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.ArrayList; @@ -21,13 +22,13 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; -import org.springframework.ai.openai.api.OpenAiApi.LogProbs; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.openai.api.OpenAiApi.LogProbs; import org.springframework.ai.openai.api.OpenAiApi.Usage; import org.springframework.util.CollectionUtils; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java index a53bc0bf6fb..7d5e961714b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api.common; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java index ebc4544218c..81051cf7b72 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai.api.common; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java index 5921940212c..93ae1cba3c5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.speech; +import java.util.Arrays; +import java.util.Objects; + import org.springframework.ai.model.ModelResult; import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechMetadata; import org.springframework.lang.Nullable; -import java.util.Arrays; -import java.util.Objects; - /** * The Speech class represents the result of speech synthesis from an AI model. It * implements the ModelResult interface with the output type of byte array. @@ -46,7 +47,7 @@ public byte[] getOutput() { @Override public OpenAiAudioSpeechMetadata getMetadata() { - return speechMetadata != null ? speechMetadata : OpenAiAudioSpeechMetadata.NULL; + return this.speechMetadata != null ? this.speechMetadata : OpenAiAudioSpeechMetadata.NULL; } public Speech withSpeechMetadata(@Nullable OpenAiAudioSpeechMetadata speechMetadata) { @@ -56,21 +57,23 @@ public Speech withSpeechMetadata(@Nullable OpenAiAudioSpeechMetadata speechMetad @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Speech that)) + } + if (!(o instanceof Speech that)) { return false; - return Arrays.equals(audio, that.audio) && Objects.equals(speechMetadata, that.speechMetadata); + } + return Arrays.equals(this.audio, that.audio) && Objects.equals(this.speechMetadata, that.speechMetadata); } @Override public int hashCode() { - return Objects.hash(Arrays.hashCode(audio), speechMetadata); + return Objects.hash(Arrays.hashCode(this.audio), this.speechMetadata); } @Override public String toString() { - return "Speech{" + "text=" + audio + ", speechMetadata=" + speechMetadata + '}'; + return "Speech{" + "text=" + this.audio + ", speechMetadata=" + this.speechMetadata + '}'; } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java index dcc96251b63..dde419268b9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.speech; import java.util.Objects; @@ -41,7 +42,7 @@ public SpeechMessage(String text) { * @return the text of this speech message */ public String getText() { - return text; + return this.text; } /** @@ -54,16 +55,18 @@ public void setText(String text) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SpeechMessage that)) + } + if (!(o instanceof SpeechMessage that)) { return false; - return Objects.equals(text, that.text); + } + return Objects.equals(this.text, that.text); } @Override public int hashCode() { - return Objects.hash(text); + return Objects.hash(this.text); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java index 9d976fd7510..f03370ce434 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java index 8cb21684d65..03fb07d6e89 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.speech; +import java.util.Objects; + import org.springframework.ai.model.ModelOptions; import org.springframework.ai.model.ModelRequest; import org.springframework.ai.openai.OpenAiAudioSpeechOptions; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - /** * The {@link SpeechPrompt} class represents a request to the OpenAI Text-to-Speech (TTS) * API. It contains a list of {@link SpeechMessage} objects, each representing a piece of @@ -33,10 +32,10 @@ */ public class SpeechPrompt implements ModelRequest { - private OpenAiAudioSpeechOptions speechOptions; - private final SpeechMessage message; + private OpenAiAudioSpeechOptions speechOptions; + public SpeechPrompt(String instructions) { this(new SpeechMessage(instructions), OpenAiAudioSpeechOptions.builder().build()); } @@ -61,21 +60,23 @@ public SpeechMessage getInstructions() { @Override public ModelOptions getOptions() { - return speechOptions; + return this.speechOptions; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SpeechPrompt that)) + } + if (!(o instanceof SpeechPrompt that)) { return false; - return Objects.equals(speechOptions, that.speechOptions) && Objects.equals(message, that.message); + } + return Objects.equals(this.speechOptions, that.speechOptions) && Objects.equals(this.message, that.message); } @Override public int hashCode() { - return Objects.hash(speechOptions, message); + return Objects.hash(this.speechOptions, this.message); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java index 028bbf22834..5b92fe770b1 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.openai.audio.speech; -import org.springframework.ai.model.ModelResponse; -import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; - import java.util.Collections; import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; + /** * Creates a new instance of SpeechResponse with the given speech result. * @@ -60,32 +60,34 @@ public SpeechResponse(Speech speech, OpenAiAudioSpeechResponseMetadata speechRes @Override public Speech getResult() { - return speech; + return this.speech; } @Override public List getResults() { - return Collections.singletonList(speech); + return Collections.singletonList(this.speech); } @Override public OpenAiAudioSpeechResponseMetadata getMetadata() { - return speechResponseMetadata; + return this.speechResponseMetadata; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SpeechResponse that)) + } + if (!(o instanceof SpeechResponse that)) { return false; - return Objects.equals(speech, that.speech) - && Objects.equals(speechResponseMetadata, that.speechResponseMetadata); + } + return Objects.equals(this.speech, that.speech) + && Objects.equals(this.speechResponseMetadata, that.speechResponseMetadata); } @Override public int hashCode() { - return Objects.hash(speech, speechResponseMetadata); + return Objects.hash(this.speech, this.speechResponseMetadata); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java index a8ae06b0739..92dcfa3473a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,10 @@ package org.springframework.ai.openai.audio.speech; -import org.springframework.ai.model.StreamingModel; import reactor.core.publisher.Flux; +import org.springframework.ai.model.StreamingModel; + /** * The {@link StreamingSpeechModel} interface provides a way to interact with the OpenAI * Text-to-Speech (TTS) API using a streaming approach, allowing you to receive the diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java index bc4da401bfe..186095dc9c7 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.metadata; -import org.springframework.ai.image.ImageGenerationMetadata; +package org.springframework.ai.openai.metadata; import java.util.Objects; +import org.springframework.ai.image.ImageGenerationMetadata; + public class OpenAiImageGenerationMetadata implements ImageGenerationMetadata { private String revisedPrompt; @@ -28,26 +29,28 @@ public OpenAiImageGenerationMetadata(String revisedPrompt) { } public String getRevisedPrompt() { - return revisedPrompt; + return this.revisedPrompt; } @Override public String toString() { - return "OpenAiImageGenerationMetadata{" + "revisedPrompt='" + revisedPrompt + '\'' + '}'; + return "OpenAiImageGenerationMetadata{" + "revisedPrompt='" + this.revisedPrompt + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof OpenAiImageGenerationMetadata that)) + } + if (!(o instanceof OpenAiImageGenerationMetadata that)) { return false; - return Objects.equals(revisedPrompt, that.revisedPrompt); + } + return Objects.equals(this.revisedPrompt, that.revisedPrompt); } @Override public int hashCode() { - return Objects.hash(revisedPrompt); + return Objects.hash(this.revisedPrompt); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java index b5622694837..71d1dcee44e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,6 @@ import org.springframework.ai.moderation.ModerationGenerationMetadata; -import java.util.Objects; - public class OpenAiModerationGenerationMetadata implements ModerationGenerationMetadata { public OpenAiModerationGenerationMetadata() { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java index 7f5f214da04..664de40a4a8 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata; import java.time.Duration; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 46ec6ffb786..4e32bd15366 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -32,10 +33,6 @@ */ public class OpenAiUsage implements Usage { - public static OpenAiUsage from(OpenAiApi.Usage usage) { - return new OpenAiUsage(usage); - } - private final OpenAiApi.Usage usage; protected OpenAiUsage(OpenAiApi.Usage usage) { @@ -43,6 +40,10 @@ protected OpenAiUsage(OpenAiApi.Usage usage) { this.usage = usage; } + public static OpenAiUsage from(OpenAiApi.Usage usage) { + return new OpenAiUsage(usage); + } + protected OpenAiApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java index 85289d85408..b6de47b4bde 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -28,6 +28,7 @@ public interface OpenAiAudioSpeechMetadata extends ResultMetadata { */ static OpenAiAudioSpeechMetadata create() { return new OpenAiAudioSpeechMetadata() { + }; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java index efcb6ebca74..e90c4097d71 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,13 +19,10 @@ import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.model.MutableResponseMetadata; -import org.springframework.ai.model.ResponseMetadata; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import java.util.HashMap; - /** * Audio speech metadata implementation for {@literal OpenAI}. * @@ -34,22 +31,11 @@ */ public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata { - protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }"; - public static final OpenAiAudioSpeechResponseMetadata NULL = new OpenAiAudioSpeechResponseMetadata() { - }; - public static OpenAiAudioSpeechResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { - Assert.notNull(result, "OpenAI speech must not be null"); - OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata(); - return speechResponseMetadata; - } + }; - public static OpenAiAudioSpeechResponseMetadata from(String result) { - Assert.notNull(result, "OpenAI speech must not be null"); - OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata(); - return speechResponseMetadata; - } + protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }"; @Nullable private RateLimit rateLimit; @@ -62,6 +48,18 @@ public OpenAiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) { this.rateLimit = rateLimit; } + public static OpenAiAudioSpeechResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { + Assert.notNull(result, "OpenAI speech must not be null"); + OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata(); + return speechResponseMetadata; + } + + public static OpenAiAudioSpeechResponseMetadata from(String result) { + Assert.notNull(result, "OpenAI speech must not be null"); + OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata(); + return speechResponseMetadata; + } + @Nullable public RateLimit getRateLimit() { RateLimit rateLimit = this.rateLimit; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java index 7fc7d1755b8..106c9d7264e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata.audio; import org.springframework.ai.audio.transcription.AudioTranscriptionResponseMetadata; import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.RateLimit; -import org.springframework.ai.model.MutableResponseMetadata; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.metadata.OpenAiRateLimit; import org.springframework.lang.Nullable; @@ -34,20 +34,11 @@ */ public class OpenAiAudioTranscriptionResponseMetadata extends AudioTranscriptionResponseMetadata { - protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }"; - public static final OpenAiAudioTranscriptionResponseMetadata NULL = new OpenAiAudioTranscriptionResponseMetadata() { - }; - public static OpenAiAudioTranscriptionResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { - Assert.notNull(result, "OpenAI Transcription must not be null"); - return new OpenAiAudioTranscriptionResponseMetadata(); - } + }; - public static OpenAiAudioTranscriptionResponseMetadata from(String result) { - Assert.notNull(result, "OpenAI Transcription must not be null"); - return new OpenAiAudioTranscriptionResponseMetadata(); - } + protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }"; @Nullable private RateLimit rateLimit; @@ -60,6 +51,16 @@ protected OpenAiAudioTranscriptionResponseMetadata(@Nullable OpenAiRateLimit rat this.rateLimit = rateLimit; } + public static OpenAiAudioTranscriptionResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { + Assert.notNull(result, "OpenAI Transcription must not be null"); + return new OpenAiAudioTranscriptionResponseMetadata(); + } + + public static OpenAiAudioTranscriptionResponseMetadata from(String result) { + Assert.notNull(result, "OpenAI Transcription must not be null"); + return new OpenAiAudioTranscriptionResponseMetadata(); + } + @Nullable public RateLimit getRateLimit() { RateLimit rateLimit = this.rateLimit; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java index 47d3d5f2d59..5c6107c8e4f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata.support; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index f472ad06033..1d46556cc1e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata.support; import java.time.Duration; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index f91edff9616..74ca86bef95 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import java.util.List; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java index 7083d798021..f11613fa16c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import org.junit.jupiter.api.Test; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index 24be2e910c1..d9e6b6ca513 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiImageApi; -import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java index 96a95ba4eda..7c7c467e2d8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import org.junit.jupiter.api.Test; +import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType; -import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.core.io.DefaultResourceLoader; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java index abe5b954d01..2701a2fe8c8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.acme; import java.util.List; @@ -24,16 +25,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.openai.OpenAiChatModel; -import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.testutils.AbstractIT; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.reader.JsonReader; import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.vectorstore.SimpleVectorStore; @@ -65,23 +66,23 @@ public class AcmeIT extends AbstractIT { @Test void beanTest() { - assertThat(bikesResource).isNotNull(); - assertThat(embeddingModel).isNotNull(); - assertThat(chatModel).isNotNull(); + assertThat(this.bikesResource).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); + assertThat(this.chatModel).isNotNull(); } // @Test void acmeChain() { // Step 1 - load documents - JsonReader jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription", "description"); + JsonReader jsonReader = new JsonReader(this.bikesResource, "name", "price", "shortDescription", "description"); var textSplitter = new TokenTextSplitter(); // Step 2 - Create embeddings and save to vector store logger.info("Creating Embeddings..."); - VectorStore vectorStore = new SimpleVectorStore(embeddingModel); + VectorStore vectorStore = new SimpleVectorStore(this.embeddingModel); vectorStore.accept(textSplitter.apply(jsonReader.get())); @@ -108,7 +109,7 @@ void acmeChain() { logger.info("Asking AI generative to reply to question."); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); logger.info("AI responded."); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); evaluateQuestionAndAnswer(userQuery, response, true); } @@ -119,7 +120,7 @@ private Message getSystemMessage(List similarDocuments) { .map(entry -> entry.getContent()) .collect(Collectors.joining(System.lineSeparator())); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemBikePrompt); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemBikePrompt); Message systemMessage = systemPromptTemplate.createMessage(Map.of("documents", documents)); return systemMessage; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java index e4399ec8bdd..4b54d809c7a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index 17a00227635..a07d400e1d4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; @@ -43,7 +44,7 @@ public class OpenAiApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = openAiApi.chatCompletionEntity( + ResponseEntity response = this.openAiApi.chatCompletionEntity( new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false)); assertThat(response).isNotNull(); @@ -53,7 +54,7 @@ void chatCompletionEntity() { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = openAiApi.chatCompletionStream( + Flux response = this.openAiApi.chatCompletionStream( new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true)); assertThat(response).isNotNull(); @@ -62,7 +63,7 @@ void chatCompletionStream() { @Test void embeddings() { - ResponseEntity> response = openAiApi + ResponseEntity> response = this.openAiApi .embeddings(new OpenAiApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java index db41af1f0d4..88e5df176bc 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api.tool; import java.util.function.Function; @@ -28,16 +29,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -65,28 +71,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index f8d0f20316b..b224fd61ea9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -32,8 +32,8 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; @@ -54,6 +54,15 @@ public class OpenAiApiToolFunctionCallIT { OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -95,7 +104,7 @@ public void toolFunctionCall() { List.of(functionTool), ToolChoiceBuilder.AUTO); // List.of(functionTool), ToolChoiceBuilder.FUNCTION("getCurrentWeather")); - ResponseEntity chatCompletion = completionApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.completionApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -116,7 +125,7 @@ public void toolFunctionCall() { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -126,9 +135,10 @@ public void toolFunctionCall() { var functionResponseRequest = new ChatCompletionRequest(messages, "gpt-4o", 0.5); - ResponseEntity chatCompletion2 = completionApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.completionApi + .chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(chatCompletion2.getBody().choices()).isNotEmpty(); @@ -144,13 +154,4 @@ public void toolFunctionCall() { } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java index f774711a602..a5c4123a912 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.api; import java.io.File; @@ -23,10 +24,10 @@ import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest; -import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest; +import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice; import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest; import org.springframework.ai.openai.api.OpenAiAudioApi.TranslationRequest; -import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice; import org.springframework.ai.openai.api.OpenAiAudioApi.TtsModel; import org.springframework.ai.openai.api.OpenAiAudioApi.WhisperModel; import org.springframework.util.FileCopyUtils; @@ -45,7 +46,7 @@ public class OpenAiAudioApiIT { @Test void speechTranscriptionAndTranslation() throws IOException { - byte[] speech = audioApi + byte[] speech = this.audioApi .createSpeech(SpeechRequest.builder() .withModel(TtsModel.TTS_1_HD.getValue()) .withInput("Hello, my name is Chris and I love Spring A.I.") @@ -57,7 +58,7 @@ void speechTranscriptionAndTranslation() throws IOException { FileCopyUtils.copy(speech, new File("target/speech.mp3")); - StructuredResponse translation = audioApi + StructuredResponse translation = this.audioApi .createTranslation( TranslationRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(), StructuredResponse.class) @@ -65,7 +66,7 @@ void speechTranscriptionAndTranslation() throws IOException { assertThat(translation.text().replaceAll(",", "")).isEqualTo("Hello my name is Chris and I love Spring AI."); - StructuredResponse transcriptionEnglish = audioApi.createTranscription( + StructuredResponse transcriptionEnglish = this.audioApi.createTranscription( TranscriptionRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(), StructuredResponse.class) .getBody(); @@ -73,7 +74,7 @@ void speechTranscriptionAndTranslation() throws IOException { assertThat(transcriptionEnglish.text().replaceAll(",", "")) .isEqualTo("Hello my name is Chris and I love Spring AI."); - StructuredResponse transcriptionDutch = audioApi + StructuredResponse transcriptionDutch = this.audioApi .createTranscription(TranscriptionRequest.builder().withFile(speech).withLanguage("nl").build(), StructuredResponse.class) .getBody(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java index 0ff96b259f1..780ab89e224 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,17 +16,18 @@ package org.springframework.ai.openai.audio.speech; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.openai.OpenAiAudioSpeechOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.boot.test.context.SpringBootTest; -import reactor.core.publisher.Flux; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -38,7 +39,7 @@ class OpenAiSpeechModelIT extends AbstractIT { @Test void shouldSuccessfullyStreamAudioBytesForEmptyMessage() { - Flux response = speechModel.stream("Today is a wonderful day to build something people love!"); + Flux response = this.speechModel.stream("Today is a wonderful day to build something people love!"); assertThat(response).isNotNull(); assertThat(response.collectList().block()).isNotNull(); System.out.println(response.collectList().block()); @@ -46,7 +47,7 @@ void shouldSuccessfullyStreamAudioBytesForEmptyMessage() { @Test void shouldProduceAudioBytesDirectlyFromMessage() { - byte[] audioBytes = speechModel.call("Today is a wonderful day to build something people love!"); + byte[] audioBytes = this.speechModel.call("Today is a wonderful day to build something people love!"); assertThat(audioBytes).hasSizeGreaterThan(0); } @@ -61,7 +62,7 @@ void shouldGenerateNonEmptyMp3AudioFromSpeechPrompt() { .build(); SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - SpeechResponse response = speechModel.call(speechPrompt); + SpeechResponse response = this.speechModel.call(speechPrompt); byte[] audioBytes = response.getResult().getOutput(); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput()).isNotEmpty(); @@ -79,7 +80,7 @@ void speechRateLimitTest() { .build(); SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - SpeechResponse response = speechModel.call(speechPrompt); + SpeechResponse response = this.speechModel.call(speechPrompt); OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata(); assertThat(metadata).isNotNull(); assertThat(metadata.getRateLimit()).isNotNull(); @@ -100,7 +101,7 @@ void shouldStreamNonEmptyResponsesForValidSpeechPrompts() { SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - Flux responseFlux = speechModel.stream(speechPrompt); + Flux responseFlux = this.speechModel.stream(speechPrompt); assertThat(responseFlux).isNotNull(); List responses = responseFlux.collectList().block(); assertThat(responses).isNotNull(); @@ -110,4 +111,4 @@ void shouldStreamNonEmptyResponsesForValidSpeechPrompts() { }); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java index 089c9c8240d..0371dd08b90 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,8 +16,11 @@ package org.springframework.ai.openai.audio.speech; +import java.time.Duration; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioSpeechOptions; import org.springframework.ai.openai.api.OpenAiAudioApi; @@ -34,12 +37,10 @@ import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; -import java.time.Duration; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** @@ -48,10 +49,10 @@ @RestClientTest(OpenAiSpeechModelWithSpeechResponseMetadataTests.Config.class) public class OpenAiSpeechModelWithSpeechResponseMetadataTests { - private static String TEST_API_KEY = "sk-1234567890"; - private static final Float SPEED = 1.0f; + private static String TEST_API_KEY = "sk-1234567890"; + @Autowired private OpenAiAudioSpeechModel openAiSpeechClient; @@ -60,7 +61,7 @@ public class OpenAiSpeechModelWithSpeechResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -77,7 +78,7 @@ void aiResponseContainsImageResponseMetadata() { SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - SpeechResponse response = openAiSpeechClient.call(speechPrompt); + SpeechResponse response = this.openAiSpeechClient.call(speechPrompt); byte[] audioBytes = response.getResult().getOutput(); assertThat(audioBytes).hasSizeGreaterThan(0); @@ -110,7 +111,7 @@ private void prepareMock() { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); httpHeaders.setContentType(MediaType.APPLICATION_OCTET_STREAM); - server.expect(requestTo("/v1/audio/speech")) + this.server.expect(requestTo("/v1/audio/speech")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess("Audio bytes as string", MediaType.APPLICATION_OCTET_STREAM).headers(httpHeaders)); @@ -132,4 +133,4 @@ public OpenAiAudioApi openAiAudioApi(RestClient.Builder builder) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java index f4a44d36a64..bf252291c33 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.transcription; import org.junit.jupiter.api.Test; @@ -44,8 +45,9 @@ void transcriptionTest() { .withResponseFormat(TranscriptResponseFormat.TEXT) .withTemperature(0f) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } @@ -60,8 +62,9 @@ void transcriptionTestWithOptions() { .withTemperature(0f) .withResponseFormat(responseFormat) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java index a7749a8f021..a1b23b4a7bf 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.audio.transcription; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; +package org.springframework.ai.openai.audio.transcription; import java.time.Duration; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.audio.transcription.AudioTranscriptionMetadata; import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; @@ -46,6 +42,12 @@ import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + /** * @author Michael Lavelle */ @@ -62,7 +64,7 @@ public class OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -118,7 +120,7 @@ private void prepareMock() { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("/v1/audio/transcriptions")) + this.server.expect(requestTo("/v1/audio/transcriptions")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java index 4e701003591..1f93fe0f69b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.transcription; import org.junit.jupiter.api.Test; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java index 80320186df0..1226618b694 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.util.List; @@ -27,7 +28,7 @@ public ActorsFilms() { } public String getActor() { - return actor; + return this.actor; } public void setActor(String actor) { @@ -35,7 +36,7 @@ public void setActor(String actor) { } public List getMovies() { - return movies; + return this.movies; } public void setMovies(List movies) { @@ -44,7 +45,7 @@ public void setMovies(List movies) { @Override public String toString() { - return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}'; } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java index 8a393c8aafd..1fcb34af273 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.net.MalformedURLException; import java.net.URL; import java.util.List; @@ -32,10 +29,12 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.model.Media; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Media; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; @@ -44,7 +43,8 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -73,41 +73,42 @@ public class MessageTypeContentTests { @BeforeEach public void beforeEach() { - chatModel = new OpenAiChatModel(openAiApi); + this.chatModel = new OpenAiChatModel(this.openAiApi); } @Test public void systemMessageSimpleContentType() { - when(openAiApi.chatCompletionEntity(pomptCaptor.capture(), headersCaptor.capture())) + when(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .thenReturn(Mockito.mock(ResponseEntity.class)); - chatModel.call(new Prompt(List.of(new SystemMessage("test message")))); + this.chatModel.call(new Prompt(List.of(new SystemMessage("test message")))); - validateStringContent(pomptCaptor.getValue()); - assertThat(headersCaptor.getValue()).isEmpty(); + validateStringContent(this.pomptCaptor.getValue()); + assertThat(this.headersCaptor.getValue()).isEmpty(); } @Test public void userMessageSimpleContentType() { - when(openAiApi.chatCompletionEntity(pomptCaptor.capture(), headersCaptor.capture())) + when(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .thenReturn(Mockito.mock(ResponseEntity.class)); - chatModel.call(new Prompt(List.of(new UserMessage("test message")))); + this.chatModel.call(new Prompt(List.of(new UserMessage("test message")))); - validateStringContent(pomptCaptor.getValue()); + validateStringContent(this.pomptCaptor.getValue()); } @Test public void streamUserMessageSimpleContentType() { - when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse); + when(this.openAiApi.chatCompletionStream(this.pomptCaptor.capture(), this.headersCaptor.capture())) + .thenReturn(this.fluxResponse); - chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe(); + this.chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe(); - validateStringContent(pomptCaptor.getValue()); - assertThat(headersCaptor.getValue()).isEmpty(); + validateStringContent(this.pomptCaptor.getValue()); + assertThat(this.headersCaptor.getValue()).isEmpty(); } private void validateStringContent(ChatCompletionRequest chatCompletionRequest) { @@ -121,28 +122,29 @@ private void validateStringContent(ChatCompletionRequest chatCompletionRequest) @Test public void userMessageWithMediaType() throws MalformedURLException { - when(openAiApi.chatCompletionEntity(pomptCaptor.capture(), headersCaptor.capture())) + when(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .thenReturn(Mockito.mock(ResponseEntity.class)); URL mediaUrl = new URL("http://test"); - chatModel.call(new Prompt( + this.chatModel.call(new Prompt( List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl)))))); - validateComplexContent(pomptCaptor.getValue()); + validateComplexContent(this.pomptCaptor.getValue()); } @Test public void streamUserMessageWithMediaType() throws MalformedURLException { - when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse); + when(this.openAiApi.chatCompletionStream(this.pomptCaptor.capture(), this.headersCaptor.capture())) + .thenReturn(this.fluxResponse); URL mediaUrl = new URL("http://test"); - chatModel + this.chatModel .stream(new Prompt( List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl)))))) .subscribe(); - validateComplexContent(pomptCaptor.getValue()); + validateComplexContent(this.pomptCaptor.getValue()); } private void validateComplexContent(ChatCompletionRequest chatCompletionRequest) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java index 7bb2a98133e..cce923b87f0 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertThrows; +package org.springframework.ai.openai.chat; import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.OpenAiChatModel; @@ -33,6 +32,9 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertThrows; + /** * @author Christian Tzolov */ diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java index 1c1087257fe..5caa848b964 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -38,13 +47,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.function.BiFunction; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -112,7 +114,7 @@ void functionCallTest(OpenAiChatOptions promptOptions) { List messages = new ArrayList<>(List.of(userMessage)); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -175,7 +177,7 @@ void streamFunctionCallTest(OpenAiChatOptions promptOptions) { List messages = new ArrayList<>(List.of(userMessage)); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -205,4 +207,4 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 66a1fd7b415..b12a10e61c8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat; import java.io.IOException; import java.net.URL; @@ -34,9 +33,10 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -47,6 +47,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; @@ -60,7 +61,7 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -75,10 +76,10 @@ public class OpenAiChatModelIT extends AbstractIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); @@ -88,16 +89,16 @@ void roleTest() { void testMessageHistory() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -111,7 +112,7 @@ void streamCompletenessTest() throws InterruptedException { StringBuilder answer = new StringBuilder(); CountDownLatch latch = new CountDownLatch(1); - Flux chatResponseFlux = streamingChatModel.stream(prompt).doOnNext(chatResponse -> { + Flux chatResponseFlux = this.streamingChatModel.stream(prompt).doOnNext(chatResponse -> { String responseContent = chatResponse.getResults().get(0).getOutput().getContent(); answer.append(responseContent); }).doOnComplete(() -> { @@ -133,7 +134,7 @@ void streamCompletenessTestWithChatResponse() throws InterruptedException { StringBuilder answer = new StringBuilder(); CountDownLatch latch = new CountDownLatch(1); - ChatClient chatClient = ChatClient.builder(openAiChatModel).build(); + ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build(); Flux chatResponseFlux = chatClient.prompt(prompt) .stream() @@ -159,7 +160,7 @@ void ensureChatResponseAsContentDoesNotSwallowBlankSpace() throws InterruptedExc StringBuilder answer = new StringBuilder(); CountDownLatch latch = new CountDownLatch(1); - ChatClient chatClient = ChatClient.builder(openAiChatModel).build(); + ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build(); Flux chatResponseFlux = chatClient.prompt(prompt) .stream() @@ -178,10 +179,10 @@ void ensureChatResponseAsContentDoesNotSwallowBlankSpace() throws InterruptedExc void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = streamingChatModel.stream(prompt); + Flux flux = this.streamingChatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -247,7 +248,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -266,14 +267,11 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -286,7 +284,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -307,7 +305,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -339,7 +337,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -364,7 +362,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = streamingChatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -390,7 +388,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -406,7 +404,7 @@ void multiModalityImageUrl(String modelName) throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -421,7 +419,7 @@ void streamingMultiModalityImageUrl() throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = streamingChatModel.stream(new Prompt(List.of(userMessage), + Flux response = this.streamingChatModel.stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())); String content = response.collectList() @@ -441,7 +439,7 @@ void streamingMultiModalityImageUrl() throws IOException { void validateCallResponseMetadata() { String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -456,4 +454,8 @@ void validateCallResponseMetadata() { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } -} \ No newline at end of file + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java index 362135eb4f4..36edf6952be 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; - import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -38,9 +42,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; @@ -62,7 +63,7 @@ public class OpenAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -80,7 +81,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -104,7 +105,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -125,7 +126,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java index c4e1198ee44..dc43e4e7f7b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat; import java.util.ArrayList; import java.util.List; @@ -25,10 +24,16 @@ import java.util.function.Function; import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -37,8 +42,8 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.ToolCallHelper; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.ToolCallHelper; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; @@ -49,12 +54,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.CollectionUtils; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.ObservationRegistry; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -64,6 +64,24 @@ class OpenAiChatModelProxyToolCallsIT { private static final String DEFAULT_MODEL = "gpt-4o-mini"; + FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation", + "Get the weather in location", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "unit"] + } + """); + @Autowired private OpenAiChatModel chatModel; @@ -71,6 +89,16 @@ class OpenAiChatModelProxyToolCallsIT { // to help to implement the function call handling logic on the client side. private ToolCallHelper toolCallHelper = new ToolCallHelper(); + @SuppressWarnings("unchecked") + private static Map getFunctionArguments(String functionArguments) { + try { + return new ObjectMapper().readValue(functionArguments, Map.class); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + // Function which will be called by the AI model. private String getWeatherInLocation(String location, String unit) { @@ -89,31 +117,13 @@ else if (location.contains("San Francisco")) { return String.format("The weather in %s is %s%s", location, temperature, unit); } - FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation", - "Get the weather in location", """ - { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["C", "F"] - } - }, - "required": ["location", "unit"] - } - """); - @Test void functionCall() throws JsonMappingException, JsonProcessingException { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); @@ -123,13 +133,13 @@ void functionCall() throws JsonMappingException, JsonProcessingException { do { - chatResponse = chatModel.call(prompt); + chatResponse = this.chatModel.call(prompt); // We will have to convert the chatResponse into OpenAI assistant message. // Note that the tool call check could be platform specific because the finish // reasons. - isToolCall = toolCallHelper.isToolCall(chatResponse, + isToolCall = this.toolCallHelper.isToolCall(chatResponse, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name())); @@ -166,8 +176,8 @@ void functionCall() throws JsonMappingException, JsonProcessingException { ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - List toolCallConversation = toolCallHelper.buildToolCallConversation(prompt.getInstructions(), - assistantMessage, toolMessageResponse); + List toolCallConversation = this.toolCallHelper + .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); assertThat(toolCallConversation).isNotEmpty(); @@ -187,7 +197,7 @@ void functionStream() throws JsonMappingException, JsonProcessingException { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); @@ -222,11 +232,11 @@ void functionStream() throws JsonMappingException, JsonProcessingException { private Flux processToolCall(Prompt prompt, Set finishReasons, Function customFunction) { - Flux chatResponses = chatModel.stream(prompt); + Flux chatResponses = this.chatModel.stream(prompt); return chatResponses.flatMap(chatResponse -> { - boolean isToolCall = toolCallHelper.isToolCall(chatResponse, finishReasons); + boolean isToolCall = this.toolCallHelper.isToolCall(chatResponse, finishReasons); if (isToolCall) { @@ -251,8 +261,8 @@ private Flux processToolCall(Prompt prompt, Set finishReas ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - List toolCallConversation = toolCallHelper.buildToolCallConversation(prompt.getInstructions(), - assistantMessage, toolMessageResponse); + List toolCallConversation = this.toolCallHelper + .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); assertThat(toolCallConversation).isNotEmpty(); @@ -271,11 +281,11 @@ void functionCall2() throws JsonMappingException, JsonProcessingException { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); - ChatResponse chatResponse = toolCallHelper.processCall(chatModel, prompt, + ChatResponse chatResponse = this.toolCallHelper.processCall(this.chatModel, prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> { @@ -305,11 +315,11 @@ void functionStream2() throws JsonMappingException, JsonProcessingException { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); - Flux responses = toolCallHelper.processStream(chatModel, prompt, + Flux responses = this.toolCallHelper.processStream(this.chatModel, prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> { @@ -340,16 +350,6 @@ void functionStream2() throws JsonMappingException, JsonProcessingException { } - @SuppressWarnings("unchecked") - private static Map getFunctionArguments(String functionArguments) { - try { - return new ObjectMapper().readValue(functionArguments, Map.class); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - @SpringBootConfiguration static class Config { @@ -369,4 +369,4 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi, List } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java index f1c5859d9cd..63bf6a88a91 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.converter.BeanOutputConverter; @@ -34,12 +40,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JacksonException; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -48,11 +49,23 @@ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiChatModelResponseFormatIT { + private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS); + private final Logger logger = LoggerFactory.getLogger(getClass()); @Autowired private OpenAiChatModel openAiChatModel; + public static boolean isValidJson(String json) { + try { + MAPPER.readTree(json); + } + catch (JacksonException e) { + return false; + } + return true; + } + @Test void jsonObject() throws JsonMappingException, JsonProcessingException { @@ -76,7 +89,7 @@ void jsonObject() throws JsonMappingException, JsonProcessingException { String content = response.getResult().getOutput().getContent(); - logger.info("Response content: {}", content); + this.logger.info("Response content: {}", content); assertThat(isValidJson(content)).isTrue(); } @@ -119,7 +132,7 @@ void jsonSchema() throws JsonMappingException, JsonProcessingException { String content = response.getResult().getOutput().getContent(); - logger.info("Response content: {}", content); + this.logger.info("Response content: {}", content); assertThat(isValidJson(content)).isTrue(); } @@ -134,8 +147,11 @@ record Steps(@JsonProperty(required = true, value = "items") Items[] items) { record Items(@JsonProperty(required = true, value = "explanation") String explanation, @JsonProperty(required = true, value = "output") String output) { + } + } + } var outputConverter = new BeanOutputConverter<>(MathReasoning.class); @@ -156,7 +172,7 @@ record Items(@JsonProperty(required = true, value = "explanation") String explan String content = response.getResult().getOutput().getContent(); - logger.info("Response content: {}", content); + this.logger.info("Response content: {}", content); MathReasoning mathReasoning = outputConverter.convert(content); @@ -165,18 +181,6 @@ record Items(@JsonProperty(required = true, value = "explanation") String explan assertThat(isValidJson(content)).isTrue(); } - private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS); - - public static boolean isValidJson(String json) { - try { - MAPPER.readTree(json); - } - catch (JacksonException e) { - return false; - } - return true; - } - @SpringBootConfiguration static class Config { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java index d222a64c3bf..443c65c7980 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.util.List; @@ -24,9 +25,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; @@ -44,14 +45,12 @@ class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT { private static final Logger logger = LoggerFactory .getLogger(OpenAiChatModelTypeReferenceBeanOutputConverterIT.class); - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void typeRefOutputConverterRecords() { BeanOutputConverter> outputConverter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { + }); String format = outputConverter.getFormat(); @@ -61,7 +60,7 @@ void typeRefOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); List actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -77,6 +76,7 @@ void typeRefStreamOutputConverterRecords() { BeanOutputConverter> outputConverter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { + }); String format = outputConverter.getFormat(); @@ -87,7 +87,7 @@ void typeRefStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -106,4 +106,8 @@ void typeRefStreamOutputConverterRecords() { assertThat(actorsFilms.get(1).movies()).hasSize(5); } -} \ No newline at end of file + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java index 4afeca4762e..2e6f6ddb687 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.time.Duration; @@ -20,16 +21,16 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -65,7 +66,7 @@ public class OpenAiChatModelWithChatResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -132,7 +133,7 @@ private void prepareMock() { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("/v1/chat/completions")) + this.server.expect(requestTo("/v1/chat/completions")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java index 4a4e30719c9..10645ee562a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; @@ -30,11 +37,6 @@ import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -48,7 +50,9 @@ public class OpenAiCompatibleChatModelIT { static OpenAiChatOptions forModelName(String modelName) { return OpenAiChatOptions.builder().withModel(modelName).build(); - }; + } + + ; static Stream openAiCompatibleApis() { Stream.Builder builder = Stream.builder(); @@ -72,7 +76,7 @@ static Stream openAiCompatibleApis() { @ParameterizedTest @MethodSource("openAiCompatibleApis") void chatCompletion(ChatModel chatModel) { - Prompt prompt = new Prompt(conversation); + Prompt prompt = new Prompt(this.conversation); ChatResponse response = chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); @@ -82,7 +86,7 @@ void chatCompletion(ChatModel chatModel) { @ParameterizedTest @MethodSource("openAiCompatibleApis") void streamCompletion(StreamingChatModel streamingChatModel) { - Prompt prompt = new Prompt(conversation); + Prompt prompt = new Prompt(this.conversation); Flux flux = streamingChatModel.stream(prompt); List responses = flux.collectList().block(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 26635b2fde9..97c38b13b9e 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,6 @@ package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.Map; import java.util.function.Function; @@ -28,11 +26,13 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.OpenAiChatModel; @@ -48,7 +48,7 @@ import org.springframework.context.annotation.Description; import org.springframework.core.ParameterizedTypeReference; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,55 +59,12 @@ public class OpenAiPaymentTransactionIT { private final static Logger logger = LoggerFactory.getLogger(OpenAiPaymentTransactionIT.class); + private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), + new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + @Autowired ChatClient chatClient; - record TransactionStatusResponse(String id, String status) { - } - - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - - advisedRequest = this.before(advisedRequest); - - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); - - this.observeAfter(advisedResponse); - - return advisedResponse; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "paymentStatus", "paymentStatuses" }) public void transactionPaymentStatuses(String functionName) { @@ -119,6 +76,7 @@ public void transactionPaymentStatuses(String functionName) { """) .call() .entity(new ParameterizedTypeReference>() { + }); logger.info("" + content); @@ -138,6 +96,7 @@ public void transactionPaymentStatuses(String functionName) { public void streamingPaymentStatuses(String functionName) { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference>() { + }); Flux flux = this.chatClient.prompt() @@ -166,20 +125,68 @@ public void streamingPaymentStatuses(String functionName) { assertThat(structure.get(2).status()).isEqualTo("rejected"); } + record TransactionStatusResponse(String id, String status) { + + } + + private static class LoggingAdvisor implements CallAroundAdvisor { + + private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = this.before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + this.observeAfter(advisedResponse); + + return advisedResponse; + } + + private AdvisedRequest before(AdvisedRequest request) { + this.logger.info("System text: \n" + request.systemText()); + this.logger.info("System params: " + request.systemParams()); + this.logger.info("User text: \n" + request.userText()); + this.logger.info("User params:" + request.userParams()); + this.logger.info("Function names: " + request.functionNames()); + + this.logger.info("Options: " + request.chatOptions().toString()); + + return request; + } + + private void observeAfter(AdvisedResponse advisedResponse) { + this.logger.info("Response: " + advisedResponse.response()); + } + + } + record Transaction(String id) { + } record Status(String name) { + } record Transactions(List transactions) { + } record Statuses(List statuses) { - } - private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), - new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + } @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index bae33e60c83..9b9e28e4383 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.util.List; @@ -26,6 +27,8 @@ import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Flux; +import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; +import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; @@ -56,8 +59,6 @@ import org.springframework.ai.openai.api.OpenAiImageApi.Data; import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageRequest; import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageResponse; -import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; -import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; import org.springframework.core.io.ClassPathResource; @@ -69,8 +70,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.when; /** @@ -80,25 +81,6 @@ @ExtendWith(MockitoExtension.class) public class OpenAiRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -119,20 +101,22 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); - - chatModel = new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().build(), null, retryTemplate); - embeddingModel = new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, - OpenAiEmbeddingOptions.builder().build(), retryTemplate); - audioTranscriptionModel = new OpenAiAudioTranscriptionModel(openAiAudioApi, + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); + + this.chatModel = new OpenAiChatModel(this.openAiApi, OpenAiChatOptions.builder().build(), null, + this.retryTemplate); + this.embeddingModel = new OpenAiEmbeddingModel(this.openAiApi, MetadataMode.EMBED, + OpenAiEmbeddingOptions.builder().build(), this.retryTemplate); + this.audioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi, OpenAiAudioTranscriptionOptions.builder() .withModel("model") .withResponseFormat(TranscriptResponseFormat.JSON) .build(), - retryTemplate); - imageModel = new OpenAiImageModel(openAiImageApi, OpenAiImageOptions.builder().build(), retryTemplate); + this.retryTemplate); + this.imageModel = new OpenAiImageModel(this.openAiImageApi, OpenAiImageOptions.builder().build(), + this.retryTemplate); } @Test @@ -143,24 +127,24 @@ public void openAiChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, new OpenAiApi.Usage(10, 10, 10)); - when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiChatNonTransientError() { - when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -172,25 +156,25 @@ public void openAiChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, null, null); - when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test @Disabled("Currently stream() does not implmement retry") public void openAiChatStreamNonTransientError() { - when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).subscribe()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).subscribe()); } @Test @@ -199,23 +183,25 @@ public void openAiEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new OpenAiApi.Usage(10, 10, 10)); - when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new TransientAiException("Transient Error 1")) + when(this.openAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiEmbeddingNonTransientError() { - when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + when(this.openAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -224,25 +210,25 @@ public void openAiAudioTranscriptionTransientError() { var expectedResponse = new StructuredResponse("nl", 6.7f, "Transcription Text", List.of(), List.of()); - when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) + when(this.openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - AudioTranscriptionResponse result = audioTranscriptionModel + AudioTranscriptionResponse result = this.audioTranscriptionModel .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac"))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(expectedResponse.text()); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiAudioTranscriptionNonTransientError() { - when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) + when(this.openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) .thenThrow(new RuntimeException("Transient Error 1")); - assertThrows(RuntimeException.class, () -> audioTranscriptionModel + assertThrows(RuntimeException.class, () -> this.audioTranscriptionModel .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac")))); } @@ -251,25 +237,44 @@ public void openAiImageTransientError() { var expectedResponse = new OpenAiImageResponse(678l, List.of(new Data("url678", "b64", "prompt"))); - when(openAiImageApi.createImage(isA(OpenAiImageRequest.class))) + when(this.openAiImageApi.createImage(isA(OpenAiImageRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiImageNonTransientError() { - when(openAiImageApi.createImage(isA(OpenAiImageRequest.class))) + when(this.openAiImageApi.createImage(isA(OpenAiImageRequest.class))) .thenThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, - () -> imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index 1af4b206502..6a35872525d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat.client; import java.io.IOException; @@ -62,9 +63,6 @@ class OpenAiChatClientIT extends AbstractIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test @Disabled("Although the Re2 advisor improves the response correctness it is not always guarantied to work.") void re2() { @@ -79,12 +77,12 @@ void re2() { """; // @formatter:off - ChatClient chatClient = ChatClient.builder(chatModel) + ChatClient chatClient = ChatClient.builder(this.chatModel) .defaultOptions(OpenAiChatOptions.builder() .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build()) .defaultUser(REASON_QUESTION) .build(); - + String response = chatClient.prompt() .advisors(new ReReadingAdvisor()) .call() @@ -101,9 +99,9 @@ void re2() { void call() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .advisors(new SimpleLoggerAdvisor()) - .system(s -> s.text(systemTextResource) + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -119,7 +117,7 @@ void call() { @Test void listOutputConverterString() { // @formatter:off - List collection = ChatClient.create(chatModel).prompt() + List collection = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -134,7 +132,7 @@ void listOutputConverterString() { void listOutputConverterBean() { // @formatter:off - List actorsFilms = ChatClient.create(chatModel).prompt() + List actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") .call() .entity(new ParameterizedTypeReference>() { @@ -151,7 +149,7 @@ void customOutputConverter() { var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); // @formatter:off - List flavors = ChatClient.create(chatModel).prompt() + List flavors = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -166,7 +164,7 @@ void customOutputConverter() { @Test void mapOutputConverter() { // @formatter:off - Map result = ChatClient.create(chatModel).prompt() + Map result = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("Provide me a List of {subject}") .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) .call() @@ -181,7 +179,7 @@ void mapOutputConverter() { void beanOutputConverter() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography for a random actor.") .call() .entity(ActorsFilms.class); @@ -195,7 +193,7 @@ void beanOutputConverter() { void beanOutputConverterRecords() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks.") .call() .entity(ActorsFilms.class); @@ -212,7 +210,7 @@ void beanStreamOutputConverterRecords() { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = ChatClient.create(chatModel) + Flux chatResponse = ChatClient.create(this.chatModel) .prompt() .options(OpenAiChatOptions.builder().withStreamUsage(true).build()) .advisors(new SimpleLoggerAdvisor()) @@ -246,7 +244,7 @@ void beanStreamOutputConverterRecords() { void functionCallTest() { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .call() @@ -262,7 +260,7 @@ void functionCallTest() { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() @@ -278,7 +276,7 @@ void defaultFunctionCallTest() { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .stream() @@ -296,7 +294,7 @@ void streamFunctionCallTest() { void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) @@ -317,7 +315,7 @@ void multiModalityImageUrl(String modelName) throws IOException { URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to .options(OpenAiChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) @@ -337,7 +335,7 @@ void streamingMultiModalityImageUrl() throws IOException { URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()) .build()) .user(u -> u.text("Explain what do you see on this picture?") @@ -353,4 +351,8 @@ void streamingMultiModalityImageUrl() throws IOException { assertThat(content).containsAnyOf("bowl", "basket"); } -} \ No newline at end of file + record ActorsFilms(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index 44ab8becf42..e8a4fbb17c5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.client; import java.lang.reflect.Method; import java.util.List; @@ -28,6 +27,8 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.openai.OpenAiTestConfiguration; @@ -40,7 +41,7 @@ import org.springframework.core.io.Resource; import org.springframework.test.context.ActiveProfiles; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -52,13 +53,21 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { + public static Function createFunction(Object obj, Method method) { + return (T t) -> { + try { + return (R) method.invoke(obj, t); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }; } @Test void turnFunctionsOnAndOffTest() { - var chatClientBuilder = ChatClient.builder(chatModel); + var chatClientBuilder = ChatClient.builder(this.chatModel); // @formatter:off String response = chatClientBuilder.build().prompt() @@ -100,7 +109,7 @@ void turnFunctionsOnAndOffTest() { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() @@ -139,7 +148,7 @@ else if (request.location().contains("San Francisco")) { }; // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .defaultToolContext(Map.of("sessionId", "123")) @@ -179,7 +188,7 @@ else if (request.location().contains("San Francisco")) { }; // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() @@ -197,7 +206,7 @@ else if (request.location().contains("San Francisco")) { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .stream() @@ -214,7 +223,7 @@ void streamFunctionCallTest() { @Test void functionCallWithExplicitInputType() throws NoSuchMethodException { - var chatClient = ChatClient.create(chatModel); + var chatClient = ChatClient.create(this.chatModel); Method currentTemp = MyFunction.class.getMethod("getCurrentTemp", MyFunction.Req.class); @@ -232,26 +241,20 @@ void functionCallWithExplicitInputType() throws NoSuchMethodException { assertThat(content).contains("23"); } - public static Function createFunction(Object obj, Method method) { - return (T t) -> { - try { - return (R) method.invoke(obj, t); - } - catch (Exception e) { - throw new RuntimeException(e); - } - }; + record ActorsFilms(String actor, List movies) { + } public static class MyFunction { - public record Req(String city) { - } - public String getCurrentTemp(Req req) { return "23"; } + public record Req(String city) { + + } + } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java index 47d3d2af7df..0ad524689b3 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat.client; import java.util.HashMap; import java.util.Map; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -25,8 +28,6 @@ import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; -import reactor.core.publisher.Flux; - /** * Drawing inspiration from the human strategy of re-reading, this advisor implements a * re-reading strategy for LLM reasoning, dubbed RE2, to enhance understanding in the @@ -91,4 +92,4 @@ public ReReadingAdvisor withOrder(int order) { return this; } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index c3cfb550a8c..00e12633498 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.io.IOException; import java.net.URL; @@ -32,6 +31,8 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -62,7 +63,7 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+") @@ -85,10 +86,10 @@ class GroqWithOpenAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -97,10 +98,10 @@ void roleTest() { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -167,7 +168,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -186,15 +187,12 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -207,7 +205,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -228,7 +226,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -259,7 +257,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -282,7 +280,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -307,7 +305,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -324,7 +322,7 @@ void multiModalityImageUrl(String modelName) throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -340,7 +338,7 @@ void streamingMultiModalityImageUrl() throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = chatModel.stream(new Prompt(List.of(userMessage))); + Flux response = this.chatModel.stream(new Prompt(List.of(userMessage))); String content = response.collectList() .block() @@ -359,7 +357,7 @@ void streamingMultiModalityImageUrl() throws IOException { @ValueSource(strings = { "llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it" }) void validateCallResponseMetadata(String model) { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -374,6 +372,10 @@ void validateCallResponseMetadata(String model) { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -389,4 +391,4 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java index 9d7a66e954e..51d29144d1b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.io.IOException; import java.net.URL; @@ -32,9 +31,10 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -45,6 +45,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; @@ -61,7 +62,7 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = MistralWithOpenAiChatModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") @@ -83,10 +84,10 @@ class MistralWithOpenAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -95,10 +96,10 @@ void roleTest() { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -165,7 +166,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -184,15 +185,12 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -205,7 +203,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -226,7 +224,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -260,7 +258,7 @@ void functionCallTest(String modelName) { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -285,7 +283,7 @@ void streamFunctionCallTest(String modelName) { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -310,7 +308,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -327,7 +325,7 @@ void multiModalityImageUrl(String modelName) throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -343,7 +341,7 @@ void streamingMultiModalityImageUrl() throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = chatModel.stream(new Prompt(List.of(userMessage))); + Flux response = this.chatModel.stream(new Prompt(List.of(userMessage))); String content = response.collectList() .block() @@ -363,7 +361,7 @@ void streamingMultiModalityImageUrl() throws IOException { "open-mixtral-8x22b" }) void validateCallResponseMetadata(String model) { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -378,6 +376,10 @@ void validateCallResponseMetadata(String model) { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -393,4 +395,4 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java index 4b713f2208f..4c5ad7bc376 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.util.ArrayList; import java.util.Arrays; @@ -28,6 +27,8 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -54,7 +55,7 @@ import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -81,10 +82,10 @@ class NvidiaWithOpenAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -93,10 +94,10 @@ void roleTest() { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -162,7 +163,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -181,15 +182,12 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -202,7 +200,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -223,7 +221,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -255,7 +253,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -278,7 +276,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -296,7 +294,7 @@ void streamFunctionCallTest() { @Test void validateCallResponseMetadata() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(DEFAULT_NVIDIA_MODEL).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -311,6 +309,10 @@ void validateCallResponseMetadata() { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -327,4 +329,4 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index a723f9dae23..523a3fa6481 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.io.IOException; import java.net.URL; @@ -32,6 +31,11 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -61,11 +65,8 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.ollama.OllamaContainer; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @Disabled("For manual smoke testing only.") @Testcontainers @@ -81,6 +82,12 @@ class OllamaWithOpenAiChatModelIT { static String baseUrl = "http://localhost:11434"; + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Autowired + private OpenAiChatModel chatModel; + @BeforeAll public static void beforeAll() throws IOException, InterruptedException { logger.info("Start pulling the '" + DEFAULT_OLLAMA_MODEL + " ' generative ... would take several minutes ..."); @@ -92,20 +99,14 @@ public static void beforeAll() throws IOException, InterruptedException { baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); } - @Value("classpath:/prompts/system-message.st") - private Resource systemResource; - - @Autowired - private OpenAiChatModel chatModel; - @Test void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -114,10 +115,10 @@ void roleTest() { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -184,7 +185,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -203,15 +204,12 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -224,7 +222,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -245,7 +243,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -278,7 +276,7 @@ void functionCallTest(String modelName) { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -302,7 +300,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -326,7 +324,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -343,7 +341,7 @@ void multiModalityImageUrl(String modelName) throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -360,7 +358,7 @@ void streamingMultiModalityImageUrl(String modelName) throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = chatModel + Flux response = this.chatModel .stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); String content = response.collectList() @@ -380,7 +378,7 @@ void streamingMultiModalityImageUrl(String modelName) throws IOException { @ValueSource(strings = { "mistral" }) void validateCallResponseMetadata(String model) { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -395,6 +393,10 @@ void validateCallResponseMetadata(String model) { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -410,4 +412,4 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index 079990e1a0c..ae3f4dbb0b1 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.embedding; -import org.junit.jupiter.api.Test; +import java.nio.charset.StandardCharsets; +import java.util.List; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; @@ -33,9 +36,6 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; -import java.nio.charset.StandardCharsets; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -50,9 +50,9 @@ class EmbeddingIT extends AbstractIT { @Test void defaultEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); @@ -60,12 +60,12 @@ void defaultEmbedding() { assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(2); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(2); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @Test void embeddingBatchDocuments() throws Exception { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); List embedded = this.embeddingModel.embed( List.of(new Document("Hello world"), new Document("Hello Spring"), new Document("Hello Spring AI!")), OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), @@ -76,10 +76,10 @@ void embeddingBatchDocuments() throws Exception { @Test void embeddingBatchDocumentsThatExceedTheLimit() throws Exception { - assertThat(embeddingModel).isNotNull(); - String contentAsString = resource.getContentAsString(StandardCharsets.UTF_8); + assertThat(this.embeddingModel).isNotNull(); + String contentAsString = this.resource.getContentAsString(StandardCharsets.UTF_8); assertThatThrownBy(() -> { - embeddingModel.embed(List.of(new Document("Hello World"), new Document(contentAsString)), + this.embeddingModel.embed(List.of(new Document("Hello World"), new Document(contentAsString)), OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), new TokenCountBatchingStrategy()); }).isInstanceOf(IllegalArgumentException.class); @@ -88,7 +88,7 @@ void embeddingBatchDocumentsThatExceedTheLimit() throws Exception { @Test void embedding3Large() { - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), OpenAiEmbeddingOptions.builder().withModel("text-embedding-3-large").build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); @@ -103,7 +103,7 @@ void embedding3Large() { @Test void textEmbeddingAda002() { - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), OpenAiEmbeddingOptions.builder().withModel("text-embedding-3-small").build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java index 6a3e4d0367f..f5f0b046288 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,8 +39,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -66,13 +68,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java index 6f0ea968a9d..cd500e2d7f7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.image; import org.assertj.core.api.Assertions; @@ -44,7 +45,7 @@ void imageAsUrlTest() { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java index 0a1d3087d3e..146504420f7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.image; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; @@ -66,10 +68,10 @@ void observationForImageOperation() { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java index a3b3160f195..47f8b60e516 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.image; +import java.util.List; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; @@ -34,12 +38,10 @@ import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** @@ -60,7 +62,7 @@ public class OpenAiImageModelWithImageResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -102,7 +104,7 @@ private void prepareMock() { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("v1/images/generations")) + this.server.expect(requestTo("v1/images/generations")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index 1c7c53e0d25..65a97c1c6b3 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -1,6 +1,23 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai.metadata; import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.api.OpenAiApi; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java index 050b05a530d..dc2aff589b0 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.metadata.support; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.metadata.support; import java.time.Duration; @@ -23,6 +22,8 @@ import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor.DurationFormatter; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link OpenAiHttpResponseHeadersInterceptor}. * diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java index ed0658862c9..f00aa781319 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.moderation; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.moderation.*; + +import org.springframework.ai.moderation.Categories; +import org.springframework.ai.moderation.CategoryScores; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationOptionsBuilder; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResponse; +import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.boot.test.context.SpringBootTest; @@ -42,7 +50,7 @@ void moderationAsUrlTestPositive() { ModerationPrompt moderationPrompt = new ModerationPrompt(instructions, options); - ModerationResponse moderationResponse = openAiModerationModel.call(moderationPrompt); + ModerationResponse moderationResponse = this.openAiModerationModel.call(moderationPrompt); assertThat(moderationResponse.getResults()).hasSize(1); @@ -96,7 +104,7 @@ void moderationAsUrlTestNegative() { ModerationPrompt moderationPrompt = new ModerationPrompt(instructions, options); - ModerationResponse moderationResponse = openAiModerationModel.call(moderationPrompt); + ModerationResponse moderationResponse = this.openAiModerationModel.call(moderationPrompt); assertThat(moderationResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java index 9d2fd687760..30d2a9c2700 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,18 @@ package org.springframework.ai.openai.moderation; +import java.util.List; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.moderation.*; + +import org.springframework.ai.moderation.Categories; +import org.springframework.ai.moderation.CategoryScores; +import org.springframework.ai.moderation.Generation; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResponse; +import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.OpenAiModerationModel; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; @@ -33,10 +42,10 @@ import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** @@ -56,7 +65,7 @@ public class OpenAiModerationModelTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -121,7 +130,7 @@ private void prepareMock() { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("v1/moderations")) + this.server.expect(requestTo("v1/moderations")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index 944852435ca..09d7c42f70f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,14 +22,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; - import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.openai.OpenAiAudioSpeechModel; @@ -88,23 +87,23 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response, String answer = response.getResult().getOutput().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); - PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, + PromptTemplate userPromptTemplate = new PromptTemplate(this.userEvaluatorResource, Map.of("question", question, "answer", answer)); SystemMessage systemMessage; if (factBased) { - systemMessage = new SystemMessage(qaEvaluatorFactBasedAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorFactBasedAnswerResource); } else { - systemMessage = new SystemMessage(qaEvaluatorAccurateAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorAccurateAnswerResource); } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = chatModel.call(prompt).getResult().getOutput().getContent(); + String yesOrNo = this.chatModel.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { - SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); + SystemMessage notRelatedSystemMessage = new SystemMessage(this.qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = chatModel.call(prompt).getResult().getOutput().getContent(); + String reasonForFailure = this.chatModel.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index bb23600ad05..d840cf6a791 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.transformer; import java.io.IOException; @@ -73,7 +74,7 @@ public class MetadataTransformerIT { @Test public void testKeywordExtractor() { - var updatedDocuments = keywordMetadataEnricher.apply(List.of(document1, document2)); + var updatedDocuments = this.keywordMetadataEnricher.apply(List.of(this.document1, this.document2)); List> keywords = updatedDocuments.stream().map(d -> d.getMetadata()).toList(); @@ -91,7 +92,7 @@ public void testKeywordExtractor() { @Test public void testSummaryExtractor() { - var updatedDocuments = summaryMetadataEnricher.apply(List.of(document1, document2)); + var updatedDocuments = this.summaryMetadataEnricher.apply(List.of(this.document1, this.document2)); List> summaries = updatedDocuments.stream().map(d -> d.getMetadata()).toList(); @@ -115,34 +116,34 @@ public void testSummaryExtractor() { @Test public void testContentFormatEnricher() { - assertThat(((DefaultContentFormatter) document1.getContentFormatter()).getExcludedEmbedMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document1.getContentFormatter()).getExcludedEmbedMetadataKeys()) .doesNotContain("NewEmbedKey"); - assertThat(((DefaultContentFormatter) document1.getContentFormatter()).getExcludedInferenceMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document1.getContentFormatter()).getExcludedInferenceMetadataKeys()) .doesNotContain("NewInferenceKey"); - assertThat(((DefaultContentFormatter) document2.getContentFormatter()).getExcludedEmbedMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document2.getContentFormatter()).getExcludedEmbedMetadataKeys()) .doesNotContain("NewEmbedKey"); - assertThat(((DefaultContentFormatter) document2.getContentFormatter()).getExcludedInferenceMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document2.getContentFormatter()).getExcludedInferenceMetadataKeys()) .doesNotContain("NewInferenceKey"); - List enrichedDocuments = contentFormatTransformer.apply(List.of(document1, document2)); + List enrichedDocuments = this.contentFormatTransformer.apply(List.of(this.document1, this.document2)); assertThat(enrichedDocuments.size()).isEqualTo(2); var doc1 = enrichedDocuments.get(0); var doc2 = enrichedDocuments.get(1); - assertThat(doc1).isEqualTo(document1); - assertThat(doc2).isEqualTo(document2); + assertThat(doc1).isEqualTo(this.document1); + assertThat(doc2).isEqualTo(this.document2); assertThat(((DefaultContentFormatter) doc1.getContentFormatter()).getTextTemplate()) - .isSameAs(defaultContentFormatter.getTextTemplate()); + .isSameAs(this.defaultContentFormatter.getTextTemplate()); assertThat(((DefaultContentFormatter) doc1.getContentFormatter()).getExcludedEmbedMetadataKeys()) .contains("NewEmbedKey"); assertThat(((DefaultContentFormatter) doc1.getContentFormatter()).getExcludedInferenceMetadataKeys()) .contains("NewInferenceKey"); assertThat(((DefaultContentFormatter) doc2.getContentFormatter()).getTextTemplate()) - .isSameAs(defaultContentFormatter.getTextTemplate()); + .isSameAs(this.defaultContentFormatter.getTextTemplate()); assertThat(((DefaultContentFormatter) doc2.getContentFormatter()).getExcludedEmbedMetadataKeys()) .contains("NewEmbedKey"); assertThat(((DefaultContentFormatter) doc2.getContentFormatter()).getExcludedInferenceMetadataKeys()) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java index 21ca5bc49b3..f58063ccd27 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,49 +13,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.vectorstore; +import java.io.File; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.CleanupMode; import org.junit.jupiter.api.io.TempDir; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.reader.JsonMetadataGenerator; import org.springframework.ai.reader.JsonReader; import org.springframework.ai.vectorstore.SimpleVectorStore; -import org.springframework.ai.reader.JsonMetadataGenerator; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; -import java.io.File; -import java.nio.file.Path; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest public class SimplePersistentVectorStoreIT { + @TempDir(cleanup = CleanupMode.ON_SUCCESS) + Path workingDir; + @Value("file:src/test/resources/data/acme/bikes.json") private Resource bikesJsonResource; @Autowired private EmbeddingModel embeddingModel; - @TempDir(cleanup = CleanupMode.ON_SUCCESS) - Path workingDir; - @Test void persist() { - JsonReader jsonReader = new JsonReader(bikesJsonResource, new ProductMetadataGenerator(), "price", "name", + JsonReader jsonReader = new JsonReader(this.bikesJsonResource, new ProductMetadataGenerator(), "price", "name", "shortDescription", "description", "tags"); List documents = jsonReader.get(); SimpleVectorStore vectorStore = new SimpleVectorStore(this.embeddingModel); vectorStore.add(documents); - File tempFile = new File(workingDir.toFile(), "temp.txt"); + File tempFile = new File(this.workingDir.toFile(), "temp.txt"); vectorStore.save(tempFile); assertThat(tempFile).isNotEmpty(); assertThat(tempFile).content().contains("Velo 99 XR1 AXS"); diff --git a/models/spring-ai-openai/src/test/resources/application-logging-test.properties b/models/spring-ai-openai/src/test/resources/application-logging-test.properties index 8e8b3b2c3c6..4466a718052 100644 --- a/models/spring-ai-openai/src/test/resources/application-logging-test.properties +++ b/models/spring-ai-openai/src/test/resources/application-logging-test.properties @@ -1 +1,17 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + logging.level.org.springframework.ai.chat.client.advisor=DEBUG diff --git a/models/spring-ai-postgresml/pom.xml b/models/spring-ai-postgresml/pom.xml index acf8349f412..0312ebd4f4c 100644 --- a/models/spring-ai-postgresml/pom.xml +++ b/models/spring-ai-postgresml/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java index 80ddb93f238..14ad9d41c47 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.sql.Array; @@ -54,34 +55,6 @@ public class PostgresMlEmbeddingModel extends AbstractEmbeddingModel implements private final JdbcTemplate jdbcTemplate; - public enum VectorType { - - PG_ARRAY("", null, (rs, i) -> { - Array embedding = rs.getArray("embedding"); - return EmbeddingUtils.toPrimitive((Float[]) embedding.getArray()); - - }), - - PG_VECTOR("::vector", "vector", (rs, i) -> { - String embedding = rs.getString("embedding"); - return EmbeddingUtils.toPrimitive(Arrays.stream((embedding.substring(1, embedding.length() - 1) - /* remove leading '[' and trailing ']' */.split(","))).map(Float::parseFloat).toList()); - }); - - private final String cast; - - private final String extensionName; - - private final RowMapper rowMapper; - - VectorType(String cast, String extensionName, RowMapper rowMapper) { - this.cast = cast; - this.extensionName = extensionName; - this.rowMapper = rowMapper; - } - - } - /** * a constructor * @param jdbcTemplate JdbcTemplate @@ -237,4 +210,32 @@ public void afterPropertiesSet() { } } + public enum VectorType { + + PG_ARRAY("", null, (rs, i) -> { + Array embedding = rs.getArray("embedding"); + return EmbeddingUtils.toPrimitive((Float[]) embedding.getArray()); + + }), + + PG_VECTOR("::vector", "vector", (rs, i) -> { + String embedding = rs.getString("embedding"); + return EmbeddingUtils.toPrimitive(Arrays.stream((embedding.substring(1, embedding.length() - 1) + /* remove leading '[' and trailing ']' */.split(","))).map(Float::parseFloat).toList()); + }); + + private final String cast; + + private final String extensionName; + + private final RowMapper rowMapper; + + VectorType(String cast, String extensionName, RowMapper rowMapper) { + this.cast = cast; + this.extensionName = extensionName; + this.rowMapper = rowMapper; + } + + } + } diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java index 1077141b6f5..2650456915c 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.util.Map; @@ -61,45 +62,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected PostgresMlEmbeddingOptions options; - - public Builder() { - this.options = new PostgresMlEmbeddingOptions(); - } - - public Builder withTransformer(String transformer) { - this.options.setTransformer(transformer); - return this; - } - - public Builder withVectorType(VectorType vectorType) { - this.options.setVectorType(vectorType); - return this; - } - - public Builder withKwargs(String kwargs) { - this.options.setKwargs(ModelOptionsUtils.objectToMap(kwargs)); - return this; - } - - public Builder withKwargs(Map kwargs) { - this.options.setKwargs(kwargs); - return this; - } - - public Builder withMetadataMode(MetadataMode metadataMode) { - this.options.setMetadataMode(metadataMode); - return this; - } - - public PostgresMlEmbeddingOptions build() { - return this.options; - } - - } - public String getTransformer() { return this.transformer; } @@ -125,7 +87,7 @@ public void setKwargs(Map kwargs) { } public MetadataMode getMetadataMode() { - return metadataMode; + return this.metadataMode; } public void setMetadataMode(MetadataMode metadataMode) { @@ -144,4 +106,43 @@ public Integer getDimensions() { return null; } + public static class Builder { + + protected PostgresMlEmbeddingOptions options; + + public Builder() { + this.options = new PostgresMlEmbeddingOptions(); + } + + public Builder withTransformer(String transformer) { + this.options.setTransformer(transformer); + return this; + } + + public Builder withVectorType(VectorType vectorType) { + this.options.setVectorType(vectorType); + return this; + } + + public Builder withKwargs(String kwargs) { + this.options.setKwargs(ModelOptionsUtils.objectToMap(kwargs)); + return this; + } + + public Builder withKwargs(Map kwargs) { + this.options.setKwargs(kwargs); + return this; + } + + public Builder withMetadataMode(MetadataMode metadataMode) { + this.options.setMetadataMode(metadataMode); + return this; + } + + public PostgresMlEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java index 23697f934c8..64627bc47c0 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.time.Duration; @@ -26,13 +27,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.ai.embedding.EmbeddingRequest; -import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType; - import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; import org.testcontainers.junit.jupiter.Container; @@ -41,6 +35,11 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase; @@ -257,4 +256,4 @@ public static class TestApplication { } -} \ No newline at end of file +} diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java index 07ce531b75c..c0464867ccc 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.util.Map; diff --git a/models/spring-ai-qianfan/pom.xml b/models/spring-ai-qianfan/pom.xml index 39ea559702b..379d29eb26c 100644 --- a/models/spring-ai-qianfan/pom.xml +++ b/models/spring-ai-qianfan/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java index 7a9448ef6af..aaf68884c92 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.EmptyUsage; @@ -48,12 +56,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.List; -import java.util.Map; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal QianFan} @@ -72,14 +74,14 @@ public class QianFanChatModel implements ChatModel, StreamingChatModel { private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); /** - * The default options used for the chat completion requests. + * The retry template used to retry the QianFan API calls. */ - private final QianFanChatOptions defaultOptions; + public final RetryTemplate retryTemplate; /** - * The retry template used to retry the QianFan API calls. + * The default options used for the chat completion requests. */ - public final RetryTemplate retryTemplate; + private final QianFanChatOptions defaultOptions; /** * Low-level access to the QianFan API. diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java index 24164ab7f68..24bff760bc9 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.qianfan.api.QianFanApi; import org.springframework.boot.context.properties.NestedConfigurationProperty; -import java.util.List; - /** * QianFanChatOptions represents the options for performing chat completion using the * QianFan API. It provides methods to set and retrieve various options like model, @@ -85,62 +87,17 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected QianFanChatOptions options; - - public Builder() { - this.options = new QianFanChatOptions(); - } - - public Builder(QianFanChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withResponseFormat(QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public QianFanChatOptions build() { - return this.options; - } - + public static QianFanChatOptions fromOptions(QianFanChatOptions fromOptions) { + return QianFanChatOptions.builder() + .withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .build(); } @Override @@ -234,74 +191,93 @@ public Integer getTopK() { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } QianFanChatOptions other = (QianFanChatOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) + if (other.frequencyPenalty != null) { return false; + } } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { return false; + } if (this.maxTokens == null) { - if (other.maxTokens != null) + if (other.maxTokens != null) { return false; + } } - else if (!this.maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) { return false; + } if (this.presencePenalty == null) { - if (other.presencePenalty != null) + if (other.presencePenalty != null) { return false; + } } - else if (!this.presencePenalty.equals(other.presencePenalty)) + else if (!this.presencePenalty.equals(other.presencePenalty)) { return false; + } if (this.responseFormat == null) { - if (other.responseFormat != null) + if (other.responseFormat != null) { return false; + } } - else if (!this.responseFormat.equals(other.responseFormat)) + else if (!this.responseFormat.equals(other.responseFormat)) { return false; + } if (this.stop == null) { - if (other.stop != null) + if (other.stop != null) { return false; + } } - else if (!stop.equals(other.stop)) + else if (!this.stop.equals(other.stop)) { return false; + } if (this.temperature == null) { - if (other.temperature != null) + if (other.temperature != null) { return false; + } } - else if (!this.temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) { return false; + } if (this.topP == null) { - if (other.topP != null) + if (other.topP != null) { return false; + } } - else if (!topP.equals(other.topP)) + else if (!this.topP.equals(other.topP)) { return false; + } return true; } @@ -310,17 +286,62 @@ public QianFanChatOptions copy() { return fromOptions(this); } - public static QianFanChatOptions fromOptions(QianFanChatOptions fromOptions) { - return QianFanChatOptions.builder() - .withModel(fromOptions.getModel()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .build(); + public static class Builder { + + protected QianFanChatOptions options; + + public Builder() { + this.options = new QianFanChatOptions(); + } + + public Builder(QianFanChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public QianFanChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java index 40323cf626d..f681cac1cb8 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -40,8 +44,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * QianFan Embedding Client implementation. * diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java index 672b68ab2f6..60700cff2f1 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -48,6 +50,29 @@ public static Builder builder() { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected QianFanEmbeddingOptions options; @@ -72,27 +97,4 @@ public QianFanEmbeddingOptions build() { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public String getUser() { - return user; - } - - public void setUser(String user) { - this.user = user; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java index de4b7e26fd7..ba2ca408fba 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageModel; @@ -37,8 +41,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * QianFanImageModel is a class that implements the ImageModel interface. It provides a * client for calling the QianFan image generation API. diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java index 7ddbd701393..d102d34feb3 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.image.ImageOptions; -import java.util.Objects; +import org.springframework.ai.image.ImageOptions; /** * QianFan Image API options. QianFanImageOptions.java @@ -88,50 +90,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final QianFanImageOptions options; - - private Builder() { - this.options = new QianFanImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withStyle(String style) { - options.setStyle(style); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public QianFanImageOptions build() { - return options; - } - - } - @Override public Integer getN() { return this.n; @@ -206,24 +164,72 @@ public void setSize(String size) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof QianFanImageOptions that)) + } + if (!(o instanceof QianFanImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(size, that.size) - && Objects.equals(style, that.style) && Objects.equals(user, that.user); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.width, that.width) && Objects.equals(this.height, that.height) + && Objects.equals(this.size, that.size) && Objects.equals(this.style, that.style) + && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(n, model, width, height, size, style, user); + return Objects.hash(this.n, this.model, this.width, this.height, this.size, this.style, this.user); } @Override public String toString() { - return "QianFanImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", width=" + width + ", height=" - + height + ", size='" + size + '\'' + ", style='" + style + '\'' + ", user='" + user + '\'' + '}'; + return "QianFanImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", width=" + this.width + + ", height=" + this.height + ", size='" + this.size + '\'' + ", style='" + this.style + '\'' + + ", user='" + this.user + '\'' + '}'; + } + + public static class Builder { + + private final QianFanImageOptions options; + + private Builder() { + this.options = new QianFanImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public QianFanImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java index a7205916138..2538e4f8b20 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.aot; import org.springframework.ai.qianfan.api.QianFanApi; diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java index 5a680338feb..da93b16b67e 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; +import java.util.function.Predicate; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.qianfan.api.auth.AuthApi; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.ParameterizedTypeReference; @@ -27,11 +34,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.List; -import java.util.function.Predicate; // @formatter:off /** @@ -125,6 +127,70 @@ public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Bu .build(); } + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(ChatCompletionChunk.class) + .takeUntil(SSE_DONE_PREDICATE); + } + + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.texts(), "The input can not be null."); + + // The input must not an empty string, and any array must be 16 dimensions or + // less. + Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); + Assert.isTrue(embeddingRequest.texts().size() <= 16, "The list must be 16 dimensions or less"); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/embeddings/{model}?access_token={token}", embeddingRequest.model, getAccessToken()) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + /** * QianFan Chat Completion Models: * QianFan Model. @@ -157,7 +223,44 @@ public enum ChatModel { } public String getValue() { - return value; + return this.value; + } + } + + /** + * QianFan Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 384 + */ + EMBEDDING_V1("embedding-v1"), + + /** + * DIMENSION: 1024 + */ + BGE_LARGE_ZH("bge_large_zh"), + + /** + * DIMENSION: 1024 + */ + BGE_LARGE_EN("bge_large_en"), + + /** + * DIMENSION: 1024 + */ + TAO_8K("tao_8k"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; } } @@ -348,79 +451,6 @@ public record ChatCompletionChunk( ) { } - /** - * Creates a model response for the given chat conversation. - * - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - return this.webClient.post() - .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(ChatCompletionChunk.class) - .takeUntil(SSE_DONE_PREDICATE); - } - - /** - * QianFan Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * DIMENSION: 384 - */ - EMBEDDING_V1("embedding-v1"), - - /** - * DIMENSION: 1024 - */ - BGE_LARGE_ZH("bge_large_zh"), - - /** - * DIMENSION: 1024 - */ - BGE_LARGE_EN("bge_large_en"), - - /** - * DIMENSION: 1024 - */ - TAO_8K("tao_8k"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - /** * Creates an embedding vector representing the input text. * @@ -502,6 +532,7 @@ public record Embedding( public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } + } /** @@ -524,32 +555,5 @@ public record EmbeddingList( // @formatter:on } - /** - * Creates an embedding vector representing the input text or token array. - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - */ - public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple - // inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.texts(), "The input can not be null."); - - // The input must not an empty string, and any array must be 16 dimensions or - // less. - Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); - Assert.isTrue(embeddingRequest.texts().size() <= 16, "The list must be 16 dimensions or less"); - - return this.restClient.post() - .uri("/v1/wenxinworkshop/embeddings/{model}?access_token={token}", embeddingRequest.model, getAccessToken()) - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } // @formatter:on diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java index b269500a4fe..5dd2744f768 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java index 2fb20942f29..2532e52df0e 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.qianfan.api.auth.AuthApi; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -24,8 +28,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import java.util.List; - /** * QianFan Image API. * @@ -76,6 +78,18 @@ public QianFanImageApi(String baseUrl, String apiKey, String secretKey, RestClie .build(); } + public ResponseEntity createImage(QianFanImageRequest qianFanImageRequest) { + Assert.notNull(qianFanImageRequest, "Image request cannot be null."); + Assert.hasLength(qianFanImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/text2image/{model}?access_token={token}", qianFanImageRequest.model(), + getAccessToken()) + .body(qianFanImageRequest) + .retrieve() + .toEntity(QianFanImageResponse.class); + } + /** * QianFan Image API model. */ @@ -122,24 +136,11 @@ public record QianFanImageResponse( @JsonProperty("created") Long created, @JsonProperty("data") List data) { } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("index") Integer index, - @JsonProperty("b64_image") String b64Image) { - } // @formatter:onn - public ResponseEntity createImage(QianFanImageRequest qianFanImageRequest) { - Assert.notNull(qianFanImageRequest, "Image request cannot be null."); - Assert.hasLength(qianFanImageRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("index") Integer index, @JsonProperty("b64_image") String b64Image) { - return this.restClient.post() - .uri("/v1/wenxinworkshop/text2image/{model}?access_token={token}", qianFanImageRequest.model(), - getAccessToken()) - .body(qianFanImageRequest) - .retrieve() - .toEntity(QianFanImageResponse.class); } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java index f8668c97fff..fb1e9723b00 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java @@ -1,10 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api; +import java.util.function.Consumer; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import java.util.function.Consumer; - public class QianFanUtils { public static Consumer defaultHeaders() { diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java index 8681343afa0..96070de8c5d 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; import com.fasterxml.jackson.annotation.JsonProperty; @@ -12,4 +28,5 @@ public record AccessTokenResponse(@JsonProperty("access_token") String accessTok @JsonProperty("refresh_token") String refreshToken, @JsonProperty("expires_in") Long expiresIn, @JsonProperty("session_key") String sessionKey, @JsonProperty("session_secret") String sessionSecret, @JsonProperty("error") String error, @JsonProperty("error_description") String errorDescription, String scope) { + } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java index 648e61fd32b..b265c8ce842 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; /** diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java index 0c3e20f2cf0..ec29676eb69 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; /** @@ -31,35 +47,35 @@ public QianFanAccessToken(AccessTokenResponse accessTokenResponse) { this.sessionKey = accessTokenResponse.sessionKey(); this.sessionSecret = accessTokenResponse.sessionSecret(); this.scope = accessTokenResponse.scope(); - this.refreshTime = getCurrentTimeInSeconds() + (long) ((double) expiresIn * FRACTION_OF_TIME_TO_LIVE); + this.refreshTime = getCurrentTimeInSeconds() + (long) ((double) this.expiresIn * FRACTION_OF_TIME_TO_LIVE); } public String getAccessToken() { - return accessToken; + return this.accessToken; } public String getRefreshToken() { - return refreshToken; + return this.refreshToken; } public Long getExpiresIn() { - return expiresIn; + return this.expiresIn; } public String getSessionKey() { - return sessionKey; + return this.sessionKey; } public String getSessionSecret() { - return sessionSecret; + return this.sessionSecret; } public Long getRefreshTime() { - return refreshTime; + return this.refreshTime; } public String getScope() { - return scope; + return this.scope; } public synchronized boolean needsRefresh() { diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java index b9af294ff29..d92ac2b592c 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; import org.springframework.http.ResponseEntity; @@ -28,9 +44,13 @@ public QianFanAuthenticator(String authUrl, String apiKey, String secretKey) { this.restClient = RestClient.builder().baseUrl(authUrl).build(); } + public static Builder builder() { + return new Builder(); + } + public QianFanAccessToken requestToken() { ResponseEntity tokenResponseEntity = this.restClient.get() - .uri(OPERATION_PATH, apiKey, secretKey) + .uri(OPERATION_PATH, this.apiKey, this.secretKey) .retrieve() .toEntity(AccessTokenResponse.class); AccessTokenResponse tokenResponse = tokenResponseEntity.getBody(); @@ -63,13 +83,9 @@ public Builder secretKey(String secretKey) { } public QianFanAuthenticator build() { - return new QianFanAuthenticator(DEFAULT_AUTH_URL, apiKey, secretKey); + return new QianFanAuthenticator(DEFAULT_AUTH_URL, this.apiKey, this.secretKey); } } - public static Builder builder() { - return new Builder(); - } - } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java index 6b5921ec932..eaa69e75502 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -26,10 +27,6 @@ */ public class QianFanUsage implements Usage { - public static QianFanUsage from(QianFanApi.Usage usage) { - return new QianFanUsage(usage); - } - private final QianFanApi.Usage usage; protected QianFanUsage(QianFanApi.Usage usage) { @@ -37,6 +34,10 @@ protected QianFanUsage(QianFanApi.Usage usage) { this.usage = usage; } + public static QianFanUsage from(QianFanApi.Usage usage) { + return new QianFanUsage(usage); + } + protected QianFanApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java index c4f76a40273..de3bf0400aa 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.qianfan.api.QianFanApi; diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java index be98b681905..d1c84192d16 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java index 38f34b72cc6..f8dae1f2092 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; +import java.util.Objects; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.stringtemplate.v4.ST; +import reactor.core.publisher.Flux; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletion; import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionChunk; @@ -26,11 +33,6 @@ import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest; import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList; import org.springframework.http.ResponseEntity; -import org.stringtemplate.v4.ST; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; import static org.assertj.core.api.Assertions.assertThat; @@ -46,7 +48,7 @@ public class QianFanApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = qianFanApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.qianFanApi.chatCompletionEntity(new ChatCompletionRequest( List.of(chatCompletionMessage), buildSystemMessage(), "ernie_speed", 0.7, false)); assertThat(response).isNotNull(); @@ -56,7 +58,7 @@ void chatCompletionEntity() { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = qianFanApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.qianFanApi.chatCompletionStream(new ChatCompletionRequest( List.of(chatCompletionMessage), buildSystemMessage(), "ernie_speed", 0.7, true)); assertThat(response).isNotNull(); @@ -65,7 +67,8 @@ void chatCompletionStream() { @Test void embeddings() { - ResponseEntity response = qianFanApi.embeddings(new QianFanApi.EmbeddingRequest("Hello world")); + ResponseEntity response = this.qianFanApi + .embeddings(new QianFanApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1); diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java index e59baeccd22..978eb7216cc 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; @@ -47,11 +54,6 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -64,25 +66,6 @@ @ExtendWith(MockitoExtension.class) public class QianFanRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private @Mock QianFanApi qianFanApi; @@ -98,13 +81,14 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void beforeEach() { RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); - chatClient = new QianFanChatModel(qianFanApi, QianFanChatOptions.builder().build(), retryTemplate); - embeddingClient = new QianFanEmbeddingModel(qianFanApi, MetadataMode.EMBED, + this.chatClient = new QianFanChatModel(this.qianFanApi, QianFanChatOptions.builder().build(), retryTemplate); + this.embeddingClient = new QianFanEmbeddingModel(this.qianFanApi, MetadataMode.EMBED, QianFanEmbeddingOptions.builder().build(), retryTemplate); - imageModel = new QianFanImageModel(qianFanImageApi, QianFanImageOptions.builder().build(), retryTemplate); + this.imageModel = new QianFanImageModel(this.qianFanImageApi, QianFanImageOptions.builder().build(), + retryTemplate); } @Test @@ -112,24 +96,24 @@ public void qianFanChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 666L, "Response", "STOP", new Usage(10, 10, 10)); - when(qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatClient.call(new Prompt("text")); + var result = this.chatClient.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanChatNonTransientError() { - when(qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatClient.call(new Prompt("text"))); } @Test @@ -138,25 +122,25 @@ public void qianFanChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion", 666L, "Response", "", true, null); - when(qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatClient.stream(new Prompt("text")); + var result = this.chatClient.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(Objects.requireNonNull(result.collectList().block()).get(0).getResult().getOutput().getContent()) .isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanChatStreamNonTransientError() { - when(qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatClient.stream(new Prompt("text")).collectList().block()); } @Test @@ -165,24 +149,25 @@ public void qianFanEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList("embedding_list", List.of(embedding), "model", null, null, new Usage(10, 10, 10)); - when(qianFanApi.embeddings(isA(EmbeddingRequest.class))) + when(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingClient + var result = this.embeddingClient .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanEmbeddingNonTransientError() { - when(qianFanApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingClient + when(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingClient .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -191,25 +176,44 @@ public void qianFanImageTransientError() { var expectedResponse = new QianFanImageResponse("1", 678L, List.of(new Data(1, "b64"))); - when(qianFanImageApi.createImage(isA(QianFanImageRequest.class))) + when(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getB64Json()).isEqualTo("b64"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanImageNonTransientError() { - when(qianFanImageApi.createImage(isA(QianFanImageRequest.class))) + when(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) .thenThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, - () -> imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java index 1ed3d9d1af0..46c4c67e5da 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.chat; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -32,11 +39,6 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -61,10 +63,10 @@ class QianFanChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -73,10 +75,10 @@ void roleTest() { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = streamingChatModel.stream(prompt); + Flux flux = this.streamingChatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -91,4 +93,4 @@ void streamRoleTest() { assertThat(stitchedResponseContent).contains("Blackbeard"); } -} \ No newline at end of file +} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java index cd70b836500..4b447ffa36a 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -62,7 +64,7 @@ public class QianFanChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -80,7 +82,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -103,7 +105,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -123,7 +125,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java index 39a53d421c3..371d5c2a833 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.embedding; +import java.util.List; + import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.qianfan.QianFanTestConfiguration; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -43,23 +44,23 @@ class EmbeddingIT { @Test void defaultEmbedding() { - Assertions.assertThat(embeddingModel).isNotNull(); + Assertions.assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); - Assertions.assertThat(embeddingModel.dimensions()).isEqualTo(1024); + Assertions.assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void batchEmbedding() { - Assertions.assertThat(embeddingModel).isNotNull(); + Assertions.assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World", "HI")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -69,7 +70,7 @@ void batchEmbedding() { assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024); - Assertions.assertThat(embeddingModel.dimensions()).isEqualTo(1024); + Assertions.assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java index c143a63ecfa..5061626a751 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -36,8 +40,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -66,13 +68,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java index 32ea9a4391e..6e4be44c737 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.image; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptionsBuilder; @@ -50,7 +52,7 @@ void imageTest() { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java index 6dbab145246..3ddaf41e586 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.image; import io.micrometer.observation.tck.TestObservationRegistry; @@ -20,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; @@ -67,10 +69,10 @@ void observationForImageOperation() { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-stability-ai/pom.xml b/models/spring-ai-stability-ai/pom.xml index d45acec7d6a..d60bc443303 100644 --- a/models/spring-ai-stability-ai/pom.xml +++ b/models/spring-ai-stability-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java index 648dae1a5d9..7626b17922e 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.stabilityai; -import org.springframework.ai.image.ImageGenerationMetadata; +package org.springframework.ai.stabilityai; import java.util.Objects; +import org.springframework.ai.image.ImageGenerationMetadata; + /** * Represents metadata associated with the image generation process in the StabilityAI * framework. @@ -50,10 +51,12 @@ public String toString() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof StabilityAiImageGenerationMetadata that)) + } + if (!(o instanceof StabilityAiImageGenerationMetadata that)) { return false; + } return Objects.equals(this.finishReason, that.finishReason) && Objects.equals(this.seed, that.seed); } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java index e1db5e2ac29..35a98059379 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; import java.util.List; import java.util.stream.Collectors; import org.springframework.ai.image.Image; -import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; @@ -51,6 +52,26 @@ public StabilityAiImageModel(StabilityAiApi stabilityAiApi, StabilityAiImageOpti this.defaultOptions = defaultOptions; } + private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, + StabilityAiImageOptions optionsToUse) { + return new StabilityAiApi.GenerateImageRequest.Builder() + .withTextPrompts(stabilityAiImagePrompt.getInstructions() + .stream() + .map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), + message.getWeight())) + .collect(Collectors.toList())) + .withHeight(optionsToUse.getHeight()) + .withWidth(optionsToUse.getWidth()) + .withCfgScale(optionsToUse.getCfgScale()) + .withClipGuidancePreset(optionsToUse.getClipGuidancePreset()) + .withSampler(optionsToUse.getSampler()) + .withSamples(optionsToUse.getN()) + .withSeed(optionsToUse.getSeed()) + .withSteps(optionsToUse.getSteps()) + .withStylePreset(optionsToUse.getStylePreset()) + .build(); + } + public StabilityAiImageOptions getOptions() { return this.defaultOptions; } @@ -82,26 +103,6 @@ public ImageResponse call(ImagePrompt imagePrompt) { return convertResponse(generateImageResponse); } - private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, - StabilityAiImageOptions optionsToUse) { - return new StabilityAiApi.GenerateImageRequest.Builder() - .withTextPrompts(stabilityAiImagePrompt.getInstructions() - .stream() - .map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), - message.getWeight())) - .collect(Collectors.toList())) - .withHeight(optionsToUse.getHeight()) - .withWidth(optionsToUse.getWidth()) - .withCfgScale(optionsToUse.getCfgScale()) - .withClipGuidancePreset(optionsToUse.getClipGuidancePreset()) - .withSampler(optionsToUse.getSampler()) - .withSamples(optionsToUse.getN()) - .withSeed(optionsToUse.getSeed()) - .withSteps(optionsToUse.getSteps()) - .withStylePreset(optionsToUse.getStylePreset()) - .build(); - } - private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) { List imageGenerationList = generateImageResponse.artifacts() .stream() diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java index e1d7c9efa5a..f3d76b3faf9 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; /** @@ -48,7 +49,7 @@ public enum StyleEnum { @Override public String toString() { - return text; + return this.text; } } \ No newline at end of file diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java index 2ee5b2f7f24..5b3b7f5460d 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai.api; import java.util.List; @@ -73,8 +74,8 @@ public StabilityAiApi(String apiKey, String model, String baseUrl, RestClient.Bu Consumer jsonContentHeaders = headers -> { headers.setBearerAuth(apiKey); headers.setAccept(List.of(MediaType.APPLICATION_JSON)); // base64 in JSON + - // metadata or return - // image in bytes. + // metadata or return + // image in bytes. headers.setContentType(MediaType.APPLICATION_JSON); }; @@ -84,6 +85,15 @@ public StabilityAiApi(String apiKey, String model, String baseUrl, RestClient.Bu .build(); } + public GenerateImageResponse generateImage(GenerateImageRequest request) { + Assert.notNull(request, "The request body can not be null."); + return this.restClient.post() + .uri("/generation/{model}/text-to-image", this.model) + .body(request) + .retrieve() + .body(GenerateImageResponse.class); + } + @JsonInclude(JsonInclude.Include.NON_NULL) public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts, @JsonProperty("height") Integer height, @JsonProperty("width") Integer width, @@ -92,15 +102,15 @@ public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts; @@ -178,28 +188,23 @@ public Builder withStylePreset(String stylePreset) { } public GenerateImageRequest build() { - return new GenerateImageRequest(textPrompts, height, width, cfgScale, clipGuidancePreset, sampler, - samples, seed, steps, stylePreset); + return new GenerateImageRequest(this.textPrompts, this.height, this.width, this.cfgScale, + this.clipGuidancePreset, this.sampler, this.samples, this.seed, this.steps, this.stylePreset); } } + } @JsonInclude(JsonInclude.Include.NON_NULL) public record GenerateImageResponse(@JsonProperty("result") String result, @JsonProperty("artifacts") List artifacts) { + public record Artifacts(@JsonProperty("seed") long seed, @JsonProperty("base64") String base64, @JsonProperty("finishReason") String finishReason) { + } - } - public GenerateImageResponse generateImage(GenerateImageRequest request) { - Assert.notNull(request, "The request body can not be null."); - return this.restClient.post() - .uri("/generation/{model}/text-to-image", this.model) - .body(request) - .retrieve() - .body(GenerateImageResponse.class); } } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index 4bf839e36b7..645e13f1ab3 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai.api; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.image.ImageOptions; import org.springframework.ai.stabilityai.StyleEnum; -import java.util.Objects; - /** * StabilityAiImageOptions is an interface that extends ImageOptions. It provides * additional stability AI specific image options. @@ -288,88 +290,9 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final StabilityAiImageOptions options; - - private Builder() { - this.options = new StabilityAiImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public Builder withCfgScale(Float cfgScale) { - options.setCfgScale(cfgScale); - return this; - } - - public Builder withClipGuidancePreset(String clipGuidancePreset) { - options.setClipGuidancePreset(clipGuidancePreset); - return this; - } - - public Builder withSampler(String sampler) { - options.setSampler(sampler); - return this; - } - - public Builder withSeed(Long seed) { - options.setSeed(seed); - return this; - } - - public Builder withSteps(Integer steps) { - options.setSteps(steps); - return this; - } - - public Builder withSamples(Integer samples) { - options.setN(samples); - return this; - } - - public Builder withStylePreset(String stylePreset) { - options.setStylePreset(stylePreset); - return this; - } - - public Builder withStylePreset(StyleEnum styleEnum) { - options.setStylePreset(styleEnum.toString()); - return this; - } - - public StabilityAiImageOptions build() { - return options; - } - - } - @Override public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -378,7 +301,7 @@ public void setN(Integer n) { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -387,7 +310,7 @@ public void setModel(String model) { @Override public Integer getWidth() { - return width; + return this.width; } public void setWidth(Integer width) { @@ -396,7 +319,7 @@ public void setWidth(Integer width) { @Override public Integer getHeight() { - return height; + return this.height; } public void setHeight(Integer height) { @@ -405,7 +328,7 @@ public void setHeight(Integer height) { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -413,7 +336,7 @@ public void setResponseFormat(String responseFormat) { } public Float getCfgScale() { - return cfgScale; + return this.cfgScale; } public void setCfgScale(Float cfgScale) { @@ -421,7 +344,7 @@ public void setCfgScale(Float cfgScale) { } public String getClipGuidancePreset() { - return clipGuidancePreset; + return this.clipGuidancePreset; } public void setClipGuidancePreset(String clipGuidancePreset) { @@ -429,7 +352,7 @@ public void setClipGuidancePreset(String clipGuidancePreset) { } public String getSampler() { - return sampler; + return this.sampler; } public void setSampler(String sampler) { @@ -437,7 +360,7 @@ public void setSampler(String sampler) { } public Long getSeed() { - return seed; + return this.seed; } public void setSeed(Long seed) { @@ -445,7 +368,7 @@ public void setSeed(Long seed) { } public Integer getSteps() { - return steps; + return this.steps; } public void setSteps(Integer steps) { @@ -464,7 +387,7 @@ public void setStyle(String style) { } public String getStylePreset() { - return stylePreset; + return this.stylePreset; } public void setStylePreset(String stylePreset) { @@ -473,30 +396,113 @@ public void setStylePreset(String stylePreset) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof StabilityAiImageOptions that)) + } + if (!(o instanceof StabilityAiImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(responseFormat, that.responseFormat) - && Objects.equals(cfgScale, that.cfgScale) - && Objects.equals(clipGuidancePreset, that.clipGuidancePreset) && Objects.equals(sampler, that.sampler) - && Objects.equals(seed, that.seed) && Objects.equals(steps, that.steps) - && Objects.equals(stylePreset, that.stylePreset); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.width, that.width) && Objects.equals(this.height, that.height) + && Objects.equals(this.responseFormat, that.responseFormat) + && Objects.equals(this.cfgScale, that.cfgScale) + && Objects.equals(this.clipGuidancePreset, that.clipGuidancePreset) + && Objects.equals(this.sampler, that.sampler) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.steps, that.steps) && Objects.equals(this.stylePreset, that.stylePreset); } @Override public int hashCode() { - return Objects.hash(n, model, width, height, responseFormat, cfgScale, clipGuidancePreset, sampler, seed, steps, - stylePreset); + return Objects.hash(this.n, this.model, this.width, this.height, this.responseFormat, this.cfgScale, + this.clipGuidancePreset, this.sampler, this.seed, this.steps, this.stylePreset); } @Override public String toString() { - return "StabilityAiImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", width=" + width + ", height=" - + height + ", responseFormat='" + responseFormat + '\'' + ", cfgScale=" + cfgScale - + ", clipGuidancePreset='" + clipGuidancePreset + '\'' + ", sampler='" + sampler + '\'' + ", seed=" - + seed + ", steps=" + steps + ", stylePreset='" + stylePreset + '\'' + '}'; + return "StabilityAiImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", width=" + this.width + + ", height=" + this.height + ", responseFormat='" + this.responseFormat + '\'' + ", cfgScale=" + + this.cfgScale + ", clipGuidancePreset='" + this.clipGuidancePreset + '\'' + ", sampler='" + + this.sampler + '\'' + ", seed=" + this.seed + ", steps=" + this.steps + ", stylePreset='" + + this.stylePreset + '\'' + '}'; + } + + public static class Builder { + + private final StabilityAiImageOptions options; + + private Builder() { + this.options = new StabilityAiImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public Builder withCfgScale(Float cfgScale) { + this.options.setCfgScale(cfgScale); + return this; + } + + public Builder withClipGuidancePreset(String clipGuidancePreset) { + this.options.setClipGuidancePreset(clipGuidancePreset); + return this; + } + + public Builder withSampler(String sampler) { + this.options.setSampler(sampler); + return this; + } + + public Builder withSeed(Long seed) { + this.options.setSeed(seed); + return this; + } + + public Builder withSteps(Integer steps) { + this.options.setSteps(steps); + return this; + } + + public Builder withSamples(Integer samples) { + this.options.setN(samples); + return this; + } + + public Builder withStylePreset(String stylePreset) { + this.options.setStylePreset(stylePreset); + return this; + } + + public Builder withStylePreset(StyleEnum styleEnum) { + this.options.setStylePreset(styleEnum.toString()); + return this; + } + + public StabilityAiImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java index 0b0c49ee749..bd98867f92a 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.stabilityai; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.stabilityai.api.StabilityAiApi; +package org.springframework.ai.stabilityai; import java.io.File; import java.io.FileOutputStream; @@ -25,6 +22,11 @@ import java.util.Base64; import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.stabilityai.api.StabilityAiApi; + import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "STABILITYAI_API_KEY", matches = ".*") @@ -32,6 +34,21 @@ public class StabilityAiApiIT { StabilityAiApi stabilityAiApi = new StabilityAiApi(System.getenv("STABILITYAI_API_KEY")); + private static void writeToFile(List artifacts) throws IOException { + int counter = 0; + String systemTempDir = System.getProperty("java.io.tmpdir"); + for (StabilityAiApi.GenerateImageResponse.Artifacts artifact : artifacts) { + counter++; + byte[] imageBytes = Base64.getDecoder().decode(artifact.base64()); + String fileName = String.format("dog%d.png", counter); + String filePath = systemTempDir + File.separator + fileName; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + } + @Test void generateImage() throws IOException { @@ -48,7 +65,7 @@ void generateImage() throws IOException { .withSteps(30) .withStylePreset("photographic"); StabilityAiApi.GenerateImageRequest request = builder.build(); - StabilityAiApi.GenerateImageResponse response = stabilityAiApi.generateImage(request); + StabilityAiApi.GenerateImageResponse response = this.stabilityAiApi.generateImage(request); assertThat(response).isNotNull(); List artifacts = response.artifacts(); @@ -61,19 +78,4 @@ void generateImage() throws IOException { } - private static void writeToFile(List artifacts) throws IOException { - int counter = 0; - String systemTempDir = System.getProperty("java.io.tmpdir"); - for (StabilityAiApi.GenerateImageResponse.Artifacts artifact : artifacts) { - counter++; - byte[] imageBytes = Base64.getDecoder().decode(artifact.base64()); - String fileName = String.format("dog%d.png", counter); - String filePath = systemTempDir + File.separator + fileName; - File file = new File(filePath); - try (FileOutputStream fos = new FileOutputStream(file)) { - fos.write(imageBytes); - } - } - } - } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java index b2de03a19df..9548cefcab9 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.Base64; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.image.Image; -import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.util.Base64; - import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = StabilityAiImageTestConfiguration.class) @@ -41,6 +42,16 @@ public class StabilityAiImageModelIT { @Autowired protected ImageModel stabilityAiImageModel; + private static void writeFile(Image image) throws IOException { + byte[] imageBytes = Base64.getDecoder().decode(image.getB64Json()); + String systemTempDir = System.getProperty("java.io.tmpdir"); + String filePath = systemTempDir + File.separator + "dog.png"; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + @Test void imageAsBase64Test() throws IOException { @@ -64,14 +75,4 @@ void imageAsBase64Test() throws IOException { writeFile(image); } - private static void writeFile(Image image) throws IOException { - byte[] imageBytes = Base64.getDecoder().decode(image.getB64Json()); - String systemTempDir = System.getProperty("java.io.tmpdir"); - String filePath = systemTempDir + File.separator + "dog.png"; - File file = new File(filePath); - try (FileOutputStream fos = new FileOutputStream(file)) { - fos.write(imageBytes); - } - } - } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java index c5271ff0085..27690f7b5b3 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; import org.springframework.ai.stabilityai.api.StabilityAiApi; diff --git a/models/spring-ai-transformers/pom.xml b/models/spring-ai-transformers/pom.xml index 086e0064f15..f32266818f4 100644 --- a/models/spring-ai-transformers/pom.xml +++ b/models/spring-ai-transformers/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java index 8fa20b199a9..a074571827f 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.io.File; diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java index e98db7ef2d8..213f578aaba 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.nio.FloatBuffer; @@ -23,8 +24,22 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.nlp.preprocess.Tokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import io.micrometer.observation.ObservationRegistry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -43,20 +58,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import ai.djl.huggingface.tokenizers.Encoding; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.djl.modality.nlp.preprocess.Tokenizer; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.DataType; -import ai.djl.ndarray.types.Shape; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; -import io.micrometer.observation.ObservationRegistry; - /** * An implementation of the AbstractEmbeddingModel that uses ONNX-based Transformer models * for text embeddings. @@ -79,10 +80,6 @@ */ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implements InitializingBean { - private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class); - - private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); - // ONNX tokenizer for the all-MiniLM-L6-v2 generative public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; @@ -92,8 +89,27 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement public final static String DEFAULT_MODEL_OUTPUT_NAME = "last_hidden_state"; + private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class); + + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private final static int EMBEDDING_AXIS = 1; + /** + * Specifies what parts of the {@link Document}'s content and metadata will be used + * for computing the embeddings. Applicable for the {@link #embed(Document)} method + * only. Has no effect on the {@link #embed(String)} or {@link #embed(List)}. Defaults + * to {@link MetadataMode#NONE}. + */ + private final MetadataMode metadataMode; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + public Map tokenizerOptions = Map.of(); + private Resource tokenizerResource = toResource(DEFAULT_ONNX_TOKENIZER_URI); private Resource modelResource = toResource(DEFAULT_ONNX_MODEL_URI); @@ -116,14 +132,6 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement */ private OrtSession session; - /** - * Specifies what parts of the {@link Document}'s content and metadata will be used - * for computing the embeddings. Applicable for the {@link #embed(Document)} method - * only. Has no effect on the {@link #embed(String)} or {@link #embed(List)}. Defaults - * to {@link MetadataMode#NONE}. - */ - private final MetadataMode metadataMode; - /** * Resource cache directory. Used to cache remote resources, such as the ONNX models, * to the local file system. @@ -143,17 +151,10 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement */ private ResourceCacheService cacheService; - public Map tokenizerOptions = Map.of(); - private String modelOutputName = DEFAULT_MODEL_OUTPUT_NAME; private Set onnxModelInputs; - /** - * Observation registry used for instrumentation. - */ - private final ObservationRegistry observationRegistry; - /** * Conventions to use for generating observations. */ @@ -174,6 +175,10 @@ public TransformersEmbeddingModel(MetadataMode metadataMode, ObservationRegistry this.observationRegistry = observationRegistry; } + private static Resource toResource(String uri) { + return new DefaultResourceLoader().getResource(uri); + } + public void setTokenizerOptions(Map tokenizerOptions) { this.tokenizerOptions = tokenizerOptions; } @@ -360,7 +365,7 @@ private Map removeUnknownModelInputs(Map return modelInputs.entrySet() .stream() - .filter(a -> onnxModelInputs.contains(a.getKey())) + .filter(a -> this.onnxModelInputs.contains(a.getKey())) .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); } @@ -399,10 +404,6 @@ private NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) { return sumEmbeddings.div(sumMask); } - private static Resource toResource(String uri) { - return new DefaultResourceLoader().getResource(uri); - } - /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention @@ -412,4 +413,4 @@ public void setObservationConvention(EmbeddingModelObservationConvention observa this.observationConvention = observationConvention; } -} \ No newline at end of file +} diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java index a8da322227a..3e6aff4135d 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.io.File; @@ -37,27 +38,27 @@ public class ResourceCacheServiceTests { @Test public void fileResourcesAreExcludedByDefault() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; var cachedResource = cache.getCachedResource(originalResourceUri); assertThat(cachedResource).isEqualTo(new DefaultResourceLoader().getResource(originalResourceUri)); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(0); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(0); } @Test public void cacheFileResources() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names, - // including 'file'. + // including 'file'. var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; var cachedResource1 = cache.getCachedResource(originalResourceUri); assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri)); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(1); // Attempt to cache the same resource again should return the already cached // resource. @@ -66,17 +67,17 @@ public void cacheFileResources() throws IOException { assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri)); assertThat(cachedResource2).isEqualTo(cachedResource1); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(1); } @Test public void cacheFileResourcesFromSameParentFolder() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names, - // including 'file'. + // including 'file'. var originalResourceUri1 = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; var cachedResource1 = cache.getCachedResource(originalResourceUri1); @@ -89,23 +90,23 @@ public void cacheFileResourcesFromSameParentFolder() throws IOException { assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1)); assertThat(cachedResource2).isNotEqualTo(cachedResource1); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1) + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1) .describedAs( "As both resources come from the same parent segments they should be cached in a single common parent."); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(2); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(2); } @Test public void cacheHttpResources() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); var originalResourceUri1 = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties"; var cachedResource1 = cache.getCachedResource(originalResourceUri1); assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1)); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(1); } } diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java index ec3c9c5ad50..3f91ad52eb4 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.transformers; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.transformers; import java.util.List; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; @@ -35,8 +37,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}. @@ -59,13 +60,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java index 8c2fb3f0194..02056499b02 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.text.DecimalFormat; diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java index be37188b3ff..8119bbca5bf 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers.samples; import java.nio.FloatBuffer; @@ -125,4 +126,4 @@ public static NDArray create(float[][][] data, NDManager manager) { return manager.create(buffer, new Shape(data.length, data[0].length, data[0][0].length)); } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml index b94de96bd87..0ce34354e8f 100644 --- a/models/spring-ai-vertex-ai-embedding/pom.xml +++ b/models/spring-ai-vertex-ai-embedding/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java index 3f2fb6ec71b..7f2cc2eb48f 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding; import java.io.IOException; -import org.springframework.util.StringUtils; - import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import org.springframework.util.StringUtils; + /** * VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex * AI embedding service. It provides methods to access the project ID, location, @@ -33,15 +34,13 @@ */ public class VertexAiEmbeddingConnectionDetails { - private static final String DEFAULT_LOCATION = "us-central1"; - public static final String DEFAULT_ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; public static final String DEFAULT_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com:443"; public static final String DEFAULT_PUBLISHER = "google"; - private PredictionServiceSettings predictionServiceSettings; + private static final String DEFAULT_LOCATION = "us-central1"; /** * Your project ID. @@ -59,6 +58,8 @@ public class VertexAiEmbeddingConnectionDetails { private final String publisher; + private PredictionServiceSettings predictionServiceSettings; + public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) { this.projectId = projectId; this.location = location; @@ -76,6 +77,27 @@ public static Builder builder() { return new Builder(); } + public String getProjectId() { + return this.projectId; + } + + public String getLocation() { + return this.location; + } + + public String getPublisher() { + return this.publisher; + } + + public EndpointName getEndpointName(String modelName) { + return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, + modelName); + } + + public PredictionServiceSettings getPredictionServiceSettings() { + return this.predictionServiceSettings; + } + public static class Builder { /** @@ -143,25 +165,4 @@ public VertexAiEmbeddingConnectionDetails build() { } - public String getProjectId() { - return this.projectId; - } - - public String getLocation() { - return this.location; - } - - public String getPublisher() { - return this.publisher; - } - - public EndpointName getEndpointName(String modelName) { - return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, - modelName); - } - - public PredictionServiceSettings getPredictionServiceSettings() { - return this.predictionServiceSettings; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java index ef0152c23a1..602afbd80e3 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vertexai.embedding; import org.springframework.ai.chat.metadata.Usage; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java index a160baeda4e..caac760df72 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java @@ -1,33 +1,33 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vertexai.embedding; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.List; - -import org.springframework.util.Assert; -import org.springframework.util.MimeType; -import org.springframework.util.StringUtils; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Struct; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.StringUtils; + /** * Utility class for constructing parameter objects for Vertex AI embedding requests. * @@ -36,6 +36,39 @@ */ public abstract class VertexAiEmbeddingUtils { + public static Value valueOf(boolean n) { + return Value.newBuilder().setBoolValue(n).build(); + } + + public static Value valueOf(String s) { + return Value.newBuilder().setStringValue(s).build(); + } + + public static Value valueOf(int n) { + return Value.newBuilder().setNumberValue(n).build(); + } + + public static Value valueOf(Struct struct) { + return Value.newBuilder().setStructValue(struct).build(); + } + + // Convert a Json string to a protobuf.Value + public static Value jsonToValue(String json) throws InvalidProtocolBufferException { + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + public static float[] toVector(Value value) { + float[] floats = new float[value.getListValue().getValuesList().size()]; + int index = 0; + for (Value v : value.getListValue().getValuesList()) { + double d = v.getNumberValue(); + floats[index++] = Double.valueOf(d).floatValue(); + } + return floats; + } + ////////////////////////////////////////////////////// // Text Only ////////////////////////////////////////////////////// @@ -404,37 +437,4 @@ else if (this.gcsUri != null) { } - public static Value valueOf(boolean n) { - return Value.newBuilder().setBoolValue(n).build(); - } - - public static Value valueOf(String s) { - return Value.newBuilder().setStringValue(s).build(); - } - - public static Value valueOf(int n) { - return Value.newBuilder().setNumberValue(n).build(); - } - - public static Value valueOf(Struct struct) { - return Value.newBuilder().setStructValue(struct).build(); - } - - // Convert a Json string to a protobuf.Value - public static Value jsonToValue(String json) throws InvalidProtocolBufferException { - Value.Builder builder = Value.newBuilder(); - JsonFormat.parser().merge(json, builder); - return builder.build(); - } - - public static float[] toVector(Value value) { - float[] floats = new float[value.getListValue().getValuesList().size()]; - int index = 0; - for (Value v : value.getListValue().getValuesList()) { - double d = v.getNumberValue(); - floats[index++] = Double.valueOf(d).floatValue(); - } - return floats; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java index 5dca008c074..b13efe0dd95 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.multimodal; +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; @@ -23,7 +31,7 @@ import com.google.protobuf.Value; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.model.Media; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingModel; @@ -34,6 +42,7 @@ import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; +import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; @@ -47,13 +56,6 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; - /** * Implementation of the Vertex AI Multimodal Embedding Model. Note: This implementation * is not yet fully functional and is subject to change. @@ -66,8 +68,6 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel private static final Logger logger = LoggerFactory.getLogger(VertexAiMultimodalEmbeddingModel.class); - public final VertexAiMultimodalEmbeddingOptions defaultOptions; - private static final MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); private static final MimeType IMAGE_MIME_TYPE = MimeTypeUtils.parseMimeType("image/*"); @@ -77,6 +77,13 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel private static final List SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp")); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream + .of(VertexAiMultimodalEmbeddingModelName.values()) + .collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, + VertexAiMultimodalEmbeddingModelName::getDimensions)); + + public final VertexAiMultimodalEmbeddingOptions defaultOptions; + private final VertexAiEmbeddingConnectionDetails connectionDetails; public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, @@ -123,9 +130,6 @@ public EmbeddingResponse call(DocumentEmbeddingRequest request) { return finalResponse; } - record DocumentMetadata(String documentId, MimeType mimeType, Object data) { - } - private EmbeddingResponse doSingleDocumentPrediction(PredictionServiceClient client, EndpointName endpointName, Document document, VertexAiMultimodalEmbeddingOptions mergedOptions) throws InvalidProtocolBufferException { @@ -252,9 +256,8 @@ public int dimensions() { return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), 768); } - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(VertexAiMultimodalEmbeddingModelName.values()) - .collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, - VertexAiMultimodalEmbeddingModelName::getDimensions)); + record DocumentMetadata(String documentId, MimeType mimeType, Object data) { + + } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java index 5dc546b5f3c..750d9816a0b 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.multimodal; import org.springframework.ai.model.EmbeddingModelDescription; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java index 78a75fef20a..89762581c52 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.multimodal; -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.util.StringUtils; +package org.springframework.ai.vertexai.embedding.multimodal; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.StringUtils; + /** * Class representing the options for Vertex AI Multimodal Embedding. * @@ -105,6 +106,48 @@ public static Builder builder() { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public Integer getVideoStartOffsetSec() { + return this.videoStartOffsetSec; + } + + public void setVideoStartOffsetSec(Integer videoStartOffsetSec) { + this.videoStartOffsetSec = videoStartOffsetSec; + } + + public Integer getVideoEndOffsetSec() { + return this.videoEndOffsetSec; + } + + public void setVideoEndOffsetSec(Integer videoEndOffsetSec) { + this.videoEndOffsetSec = videoEndOffsetSec; + } + + public Integer getVideoIntervalSec() { + return this.videoIntervalSec; + } + + public void setVideoIntervalSec(Integer videoIntervalSec) { + this.videoIntervalSec = videoIntervalSec; + } + public static class Builder { protected VertexAiMultimodalEmbeddingOptions options; @@ -168,46 +211,4 @@ public VertexAiMultimodalEmbeddingOptions build() { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Integer getDimensions() { - return this.dimensions; - } - - public void setDimensions(Integer dimensions) { - this.dimensions = dimensions; - } - - public Integer getVideoStartOffsetSec() { - return this.videoStartOffsetSec; - } - - public void setVideoStartOffsetSec(Integer videoStartOffsetSec) { - this.videoStartOffsetSec = videoStartOffsetSec; - } - - public Integer getVideoEndOffsetSec() { - return this.videoEndOffsetSec; - } - - public void setVideoEndOffsetSec(Integer videoEndOffsetSec) { - this.videoEndOffsetSec = videoEndOffsetSec; - } - - public Integer getVideoIntervalSec() { - return this.videoIntervalSec; - } - - public void setVideoIntervalSec(Integer videoIntervalSec) { - this.videoIntervalSec = videoIntervalSec; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 26da5fea0c0..31d2846e9c5 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.text; import java.io.IOException; @@ -22,6 +23,13 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Value; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -46,14 +54,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.google.cloud.aiplatform.v1.EndpointName; -import com.google.cloud.aiplatform.v1.PredictRequest; -import com.google.cloud.aiplatform.v1.PredictResponse; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; -import com.google.protobuf.Value; - -import io.micrometer.observation.ObservationRegistry; - /** * A class representing a Vertex AI Text Embedding Model. * @@ -65,6 +65,11 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream + .of(VertexAiTextEmbeddingModelName.values()) + .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, + VertexAiTextEmbeddingModelName::getDimensions)); + public final VertexAiTextEmbeddingOptions defaultOptions; private final VertexAiEmbeddingConnectionDetails connectionDetails; @@ -131,7 +136,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, finalOptions); - PredictResponse embeddingResponse = retryTemplate + PredictResponse embeddingResponse = this.retryTemplate .execute(context -> getPredictResponse(client, predictRequestBuilder)); int index = 0; @@ -228,11 +233,6 @@ public int dimensions() { return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); } - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(VertexAiTextEmbeddingModelName.values()) - .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, - VertexAiTextEmbeddingModelName::getDimensions)); - /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention @@ -242,4 +242,4 @@ public void setObservationConvention(EmbeddingModelObservationConvention observa this.observationConvention = observationConvention; } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java index c49471d061c..327d7950c27 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.text; import org.springframework.ai.model.EmbeddingModelDescription; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java index fe08b2d4bf2..4de1f3375fe 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.text; -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.util.StringUtils; +package org.springframework.ai.vertexai.embedding.text; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.StringUtils; + /** * @author Christian Tzolov * @since 1.0.0 @@ -31,6 +32,100 @@ public class VertexAiTextEmbeddingOptions implements EmbeddingOptions { public static final String DEFAULT_MODEL_NAME = VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName(); + /** + * The embedding model name to use. Supported models are: text-embedding-004, + * text-multilingual-embedding-002 and multimodalembedding@001. + */ + private @JsonProperty("model") String model; + + // @formatter:off + + /** + * The intended downstream application to help the model produce better quality embeddings. + * Not all model versions support all task types. + */ + private @JsonProperty("task") TaskType taskType; + + /** + * The number of dimensions the resulting output embeddings should have. + * Supported for model version 004 and later. You can use this parameter to reduce the + * embedding size, for example, for storage optimization. + */ + private @JsonProperty("dimensions") Integer dimensions; + + /** + * Optional title, only valid with task_type=RETRIEVAL_DOCUMENT. + */ + private @JsonProperty("title") String title; + + /** + * When set to true, input text will be truncated. When set to false, an error is returned + * if the input text is longer than the maximum length supported by the model. Defaults to true. + */ + private @JsonProperty("autoTruncate") Boolean autoTruncate; + + public static Builder builder() { + return new Builder(); + } + + + // @formatter:on + + public VertexAiTextEmbeddingOptions initializeDefaults() { + + if (this.getTaskType() == null) { + this.setTaskType(TaskType.RETRIEVAL_DOCUMENT); + } + + if (StringUtils.hasText(this.getTitle()) && this.getTaskType() != TaskType.RETRIEVAL_DOCUMENT) { + throw new IllegalArgumentException("Title is only valid with task_type=RETRIEVAL_DOCUMENT"); + } + + return this; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public TaskType getTaskType() { + return this.taskType; + } + + public void setTaskType(TaskType taskType) { + this.taskType = taskType; + } + + @Override + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public String getTitle() { + return this.title; + } + + public void setTitle(String user) { + this.title = user; + } + + public Boolean getAutoTruncate() { + return this.autoTruncate; + } + + public void setAutoTruncate(Boolean autoTruncate) { + this.autoTruncate = autoTruncate; + } + public enum TaskType { /** @@ -71,45 +166,6 @@ public enum TaskType { } - // @formatter:off - /** - * The embedding model name to use. Supported models are: - * text-embedding-004, text-multilingual-embedding-002 and multimodalembedding@001. - */ - private @JsonProperty("model") String model; - - /** - * The intended downstream application to help the model produce better quality embeddings. - * Not all model versions support all task types. - */ - private @JsonProperty("task") TaskType taskType; - - /** - * The number of dimensions the resulting output embeddings should have. - * Supported for model version 004 and later. You can use this parameter to reduce the - * embedding size, for example, for storage optimization. - */ - private @JsonProperty("dimensions") Integer dimensions; - - /** - * Optional title, only valid with task_type=RETRIEVAL_DOCUMENT. - */ - private @JsonProperty("title") String title; - - - /** - * When set to true, input text will be truncated. When set to false, an error is returned - * if the input text is longer than the maximum length supported by the model. Defaults to true. - */ - private @JsonProperty("autoTruncate") Boolean autoTruncate; - - - // @formatter:on - - public static Builder builder() { - return new Builder(); - } - public static class Builder { protected VertexAiTextEmbeddingOptions options; @@ -170,59 +226,4 @@ public VertexAiTextEmbeddingOptions build() { } - public VertexAiTextEmbeddingOptions initializeDefaults() { - - if (this.getTaskType() == null) { - this.setTaskType(TaskType.RETRIEVAL_DOCUMENT); - } - - if (StringUtils.hasText(this.getTitle()) && this.getTaskType() != TaskType.RETRIEVAL_DOCUMENT) { - throw new IllegalArgumentException("Title is only valid with task_type=RETRIEVAL_DOCUMENT"); - } - - return this; - } - - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public TaskType getTaskType() { - return this.taskType; - } - - public void setTaskType(TaskType taskType) { - this.taskType = taskType; - } - - @Override - public Integer getDimensions() { - return this.dimensions; - } - - public void setDimensions(Integer dimensions) { - this.dimensions = dimensions; - } - - public String getTitle() { - return this.title; - } - - public void setTitle(String user) { - this.title = user; - } - - public Boolean getAutoTruncate() { - return this.autoTruncate; - } - - public void setAutoTruncate(Boolean autoTruncate) { - this.autoTruncate = autoTruncate; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java index 6caf874323e..b92079d5596 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.multimodal; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.embedding.multimodal; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.model.Media; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.model.Media; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -33,6 +33,8 @@ import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest(classes = VertexAiMultimodalEmbeddingModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") @@ -49,7 +51,7 @@ void multipleInstancesEmbedding() { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(new Document("Hello World"), new Document("Hello World2")); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -76,7 +78,7 @@ void multipleInstancesEmbedding() { .as("Total tokens in metadata should be 0") .isEqualTo(0L); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -86,7 +88,7 @@ void textContentEmbedding() { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -98,18 +100,18 @@ void textContentEmbedding() { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test void textMediaEmbedding() { - assertThat(multiModelEmbeddingModel).isNotNull(); + assertThat(this.multiModelEmbeddingModel).isNotNull(); var document = Document.builder().withMedia(new Media(MimeTypeUtils.TEXT_PLAIN, "Hello World")).build(); DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -121,7 +123,7 @@ void textMediaEmbedding() { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -133,7 +135,7 @@ void imageEmbedding() { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); @@ -147,7 +149,7 @@ void imageEmbedding() { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -159,7 +161,7 @@ void videoEmbedding() { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); @@ -172,7 +174,7 @@ void videoEmbedding() { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -186,7 +188,7 @@ void textImageAndVideoEmbedding() { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(3); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -206,7 +208,7 @@ void textImageAndVideoEmbedding() { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @SpringBootConfiguration diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java index 090d683e5a4..baaeedd8acb 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -20,12 +20,11 @@ import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.retry.support.RetryTemplate; -import java.io.IOException; - public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel { private PredictionServiceClient mockPredictionServiceClient; @@ -43,16 +42,16 @@ public void setMockPredictionServiceClient(PredictionServiceClient mockPredictio @Override PredictionServiceClient createPredictionServiceClient() { - if (mockPredictionServiceClient != null) { - return mockPredictionServiceClient; + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient; } return super.createPredictionServiceClient(); } @Override PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { - if (mockPredictionServiceClient != null) { - return mockPredictionServiceClient.predict(predictRequestBuilder.build()); + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient.predict(predictRequestBuilder.build()); } return super.getPredictResponse(client, predictRequestBuilder); } @@ -64,8 +63,8 @@ public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictReque @Override protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, VertexAiTextEmbeddingOptions finalOptions) { - if (mockPredictRequestBuilder != null) { - return mockPredictRequestBuilder; + if (this.mockPredictRequestBuilder != null) { + return this.mockPredictRequestBuilder; } return super.getPredictRequestBuilder(request, endpointName, finalOptions); } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java index f98c96b1baa..d1701b7a8d3 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.text; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.embedding.text; import java.util.List; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; @@ -30,6 +30,8 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") @@ -43,11 +45,11 @@ class VertexAiTextEmbeddingModelIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "text-embedding-004", "text-multilingual-embedding-002" }) void defaultEmbedding(String modelName) { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); var options = VertexAiTextEmbeddingOptions.builder().withModel(modelName).build(); - EmbeddingResponse embeddingResponse = embeddingModel + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of("Hello World", "World is Big"), options)); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -60,7 +62,7 @@ void defaultEmbedding(String modelName) { .as("Total tokens in metadata should be 5") .isEqualTo(5L); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @SpringBootConfiguration diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java index f6ac7c5b531..9a277d40348 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.text; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.embedding.text; import java.util.List; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -36,9 +39,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}. @@ -66,13 +67,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java index 2638e97e7cd..5757fe5a4fe 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,11 @@ package org.springframework.ai.vertexai.embedding.text; +import java.util.List; + import com.google.cloud.aiplatform.v1.PredictRequest; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; import com.google.protobuf.Struct; import com.google.protobuf.Value; @@ -27,8 +29,9 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.embedding.EmbeddingResponse; + import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; @@ -37,8 +40,6 @@ import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; @@ -52,25 +53,6 @@ @ExtendWith(MockitoExtension.class) public class VertexAiTextEmbeddingRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -91,15 +73,15 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void setUp() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); - - embeddingModel = new TestVertexAiTextEmbeddingModel(mockConnectionDetails, - VertexAiTextEmbeddingOptions.builder().build(), retryTemplate); - embeddingModel.setMockPredictionServiceClient(mockPredictionServiceClient); - embeddingModel.setMockPredictRequestBuilder(mockPredictRequestBuilder); - when(mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance()); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); + + this.embeddingModel = new TestVertexAiTextEmbeddingModel(this.mockConnectionDetails, + VertexAiTextEmbeddingOptions.builder().build(), this.retryTemplate); + this.embeddingModel.setMockPredictionServiceClient(this.mockPredictionServiceClient); + this.embeddingModel.setMockPredictRequestBuilder(this.mockPredictRequestBuilder); + when(this.mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance()); } @Test @@ -130,32 +112,51 @@ public void vertexAiEmbeddingTransientError() { .build(); // Setup the mock PredictionServiceClient - when(mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1")) + when(this.mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(mockResponse); - EmbeddingResponse result = embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); + EmbeddingResponse result = this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResults()).hasSize(1); assertThat(result.getResults().get(0).getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); - verify(mockPredictRequestBuilder, times(3)).build(); + verify(this.mockPredictRequestBuilder, times(3)).build(); } @Test public void vertexAiEmbeddingNonTransientError() { // Setup the mock PredictionServiceClient to throw a non-transient error - when(mockPredictionServiceClient.predict(any())).thenThrow(new RuntimeException("Non Transient Error")); + when(this.mockPredictionServiceClient.predict(any())).thenThrow(new RuntimeException("Non Transient Error")); // Assert that a RuntimeException is thrown and not retried assertThrows(RuntimeException.class, - () -> embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))); + () -> this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))); // Verify that predict was called only once (no retries for non-transient errors) - verify(mockPredictionServiceClient, times(1)).predict(any()); + verify(this.mockPredictionServiceClient, times(1)).predict(any()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-vertex-ai-gemini/pom.xml b/models/spring-ai-vertex-ai-gemini/pom.xml index 57f5dcd5ce7..230c5cd6796 100644 --- a/models/spring-ai-vertex-ai-gemini/pom.xml +++ b/models/spring-ai-vertex-ai-gemini/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java index c1d7c34bb5d..fe5e8e52e6e 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; import java.io.File; @@ -56,21 +57,6 @@ public abstract class MimeTypeDetector { */ private static final Map GEMINI_MIME_TYPES = new HashMap<>(); - static { - // Custom MIME type mappings here - GEMINI_MIME_TYPES.put("png", MimeTypeUtils.IMAGE_PNG); - GEMINI_MIME_TYPES.put("jpeg", MimeTypeUtils.IMAGE_JPEG); - GEMINI_MIME_TYPES.put("jpg", MimeTypeUtils.IMAGE_JPEG); - GEMINI_MIME_TYPES.put("gif", MimeTypeUtils.IMAGE_GIF); - GEMINI_MIME_TYPES.put("mov", new MimeType("video", "mov")); - GEMINI_MIME_TYPES.put("mp4", new MimeType("video", "mp4")); - GEMINI_MIME_TYPES.put("mpg", new MimeType("video", "mpg")); - GEMINI_MIME_TYPES.put("avi", new MimeType("video", "avi")); - GEMINI_MIME_TYPES.put("wmv", new MimeType("video", "wmv")); - GEMINI_MIME_TYPES.put("mpegps", new MimeType("mpegps", "mp4")); - GEMINI_MIME_TYPES.put("flv", new MimeType("video", "flv")); - } - public static MimeType getMimeType(URL url) { return getMimeType(url.getFile()); } @@ -115,4 +101,19 @@ public static MimeType getMimeType(String path) { String.format("Unable to detect the MIME type of '%s'. Please provide it explicitly.", path)); } + static { + // Custom MIME type mappings here + GEMINI_MIME_TYPES.put("png", MimeTypeUtils.IMAGE_PNG); + GEMINI_MIME_TYPES.put("jpeg", MimeTypeUtils.IMAGE_JPEG); + GEMINI_MIME_TYPES.put("jpg", MimeTypeUtils.IMAGE_JPEG); + GEMINI_MIME_TYPES.put("gif", MimeTypeUtils.IMAGE_GIF); + GEMINI_MIME_TYPES.put("mov", new MimeType("video", "mov")); + GEMINI_MIME_TYPES.put("mp4", new MimeType("video", "mp4")); + GEMINI_MIME_TYPES.put("mpg", new MimeType("video", "mpg")); + GEMINI_MIME_TYPES.put("avi", new MimeType("video", "avi")); + GEMINI_MIME_TYPES.put("wmv", new MimeType("video", "wmv")); + GEMINI_MIME_TYPES.put("mpegps", new MimeType("mpegps", "mp4")); + GEMINI_MIME_TYPES.put("flv", new MimeType("video", "flv")); + } + } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index dfe008a47c4..67d6bed36af 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.*; +import com.google.cloud.vertexai.api.Candidate; import com.google.cloud.vertexai.api.Candidate.FinishReason; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; +import com.google.cloud.vertexai.api.FunctionDeclaration; +import com.google.cloud.vertexai.api.FunctionResponse; +import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.GoogleSearchRetrieval; +import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.generativeai.GenerativeModel; import com.google.cloud.vertexai.generativeai.PartMaker; import com.google.cloud.vertexai.generativeai.ResponseStream; import com.google.protobuf.Struct; import com.google.protobuf.util.JsonFormat; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -60,18 +83,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - /** * @author Christian Tzolov * @author Grogdunn @@ -106,54 +117,6 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - public enum GeminiMessageType { - - USER("user"), - - MODEL("model"); - - GeminiMessageType(String value) { - this.value = value; - } - - public final String value; - - public String getValue() { - return this.value; - } - - } - - public enum ChatModel implements ChatModelDescription { - - /** - * Deprecated by Goolgle in favor of 1.5 pro and flash models. - */ - GEMINI_PRO_VISION("gemini-pro-vision"), - - GEMINI_PRO("gemini-pro"), - - GEMINI_1_5_PRO("gemini-1.5-pro-001"), - - GEMINI_1_5_FLASH("gemini-1.5-flash-001"); - - ChatModel(String value) { - this.value = value; - } - - public final String value; - - public String getValue() { - return this.value; - } - - @Override - public String getName() { - return this.value; - } - - } - public VertexAiGeminiChatModel(VertexAI vertexAI) { this(vertexAI, VertexAiGeminiChatOptions.builder().withModel(ChatModel.GEMINI_1_5_PRO).withTemperature(0.8).build()); @@ -198,6 +161,124 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions opti this.observationRegistry = observationRegistry; } + private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type) { + + Assert.notNull(type, "Message type must not be null"); + + switch (type) { + case SYSTEM: + case USER: + case TOOL: + return GeminiMessageType.USER; + case ASSISTANT: + return GeminiMessageType.MODEL; + default: + throw new IllegalArgumentException("Unsupported message type: " + type); + } + } + + static List messageToGeminiParts(Message message) { + + if (message instanceof SystemMessage systemMessage) { + + List parts = new ArrayList<>(); + + if (systemMessage.getContent() != null) { + parts.add(Part.newBuilder().setText(systemMessage.getContent()).build()); + } + + return parts; + } + else if (message instanceof UserMessage userMessage) { + List parts = new ArrayList<>(); + if (userMessage.getContent() != null) { + parts.add(Part.newBuilder().setText(userMessage.getContent()).build()); + } + + parts.addAll(mediaToParts(userMessage.getMedia())); + + return parts; + } + else if (message instanceof AssistantMessage assistantMessage) { + List parts = new ArrayList<>(); + if (StringUtils.hasText(assistantMessage.getContent())) { + parts.add(Part.newBuilder().setText(assistantMessage.getContent()).build()); + } + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + parts.addAll(assistantMessage.getToolCalls() + .stream() + .map(toolCall -> Part.newBuilder() + .setFunctionCall(FunctionCall.newBuilder() + .setName(toolCall.name()) + .setArgs(jsonToStruct(toolCall.arguments())) + .build()) + .build()) + .toList()); + } + return parts; + } + else if (message instanceof ToolResponseMessage toolResponseMessage) { + + return toolResponseMessage.getResponses() + .stream() + .map(response -> Part.newBuilder() + .setFunctionResponse(FunctionResponse.newBuilder() + .setName(response.name()) + .setResponse(jsonToStruct(response.responseData())) + .build()) + .build()) + .toList(); + } + else { + throw new IllegalArgumentException("Gemini doesn't support message type: " + message.getClass()); + } + } + + private static List mediaToParts(Collection media) { + List parts = new ArrayList<>(); + + List mediaParts = media.stream() + .map(mediaData -> PartMaker.fromMimeTypeAndData(mediaData.getMimeType().toString(), mediaData.getData())) + .toList(); + + if (!CollectionUtils.isEmpty(mediaParts)) { + parts.addAll(mediaParts); + } + + return parts; + } + + private static String structToJson(Struct struct) { + try { + return JsonFormat.printer().print(struct); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Struct jsonToStruct(String json) { + try { + var structBuilder = Struct.newBuilder(); + JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder); + return structBuilder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Schema jsonToSchema(String json) { + try { + var schemaBuilder = Schema.newBuilder(); + JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder); + return schemaBuilder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini @Override public ChatResponse call(Prompt prompt) { @@ -346,10 +427,6 @@ private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse resp return ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build(); } - @JsonInclude(Include.NON_NULL) - public record GeminiRequest(List contents, GenerativeModel model) { - } - private VertexAiGeminiChatOptions vertexAiGeminiChatOptions(Prompt prompt) { VertexAiGeminiChatOptions updatedRuntimeOptions = VertexAiGeminiChatOptions.builder().build(); if (prompt.getOptions() != null) { @@ -480,93 +557,6 @@ private List toGeminiContent(List instrucitons) { return contents; } - private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type) { - - Assert.notNull(type, "Message type must not be null"); - - switch (type) { - case SYSTEM: - case USER: - case TOOL: - return GeminiMessageType.USER; - case ASSISTANT: - return GeminiMessageType.MODEL; - default: - throw new IllegalArgumentException("Unsupported message type: " + type); - } - } - - static List messageToGeminiParts(Message message) { - - if (message instanceof SystemMessage systemMessage) { - - List parts = new ArrayList<>(); - - if (systemMessage.getContent() != null) { - parts.add(Part.newBuilder().setText(systemMessage.getContent()).build()); - } - - return parts; - } - else if (message instanceof UserMessage userMessage) { - List parts = new ArrayList<>(); - if (userMessage.getContent() != null) { - parts.add(Part.newBuilder().setText(userMessage.getContent()).build()); - } - - parts.addAll(mediaToParts(userMessage.getMedia())); - - return parts; - } - else if (message instanceof AssistantMessage assistantMessage) { - List parts = new ArrayList<>(); - if (StringUtils.hasText(assistantMessage.getContent())) { - parts.add(Part.newBuilder().setText(assistantMessage.getContent()).build()); - } - if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { - parts.addAll(assistantMessage.getToolCalls() - .stream() - .map(toolCall -> Part.newBuilder() - .setFunctionCall(FunctionCall.newBuilder() - .setName(toolCall.name()) - .setArgs(jsonToStruct(toolCall.arguments())) - .build()) - .build()) - .toList()); - } - return parts; - } - else if (message instanceof ToolResponseMessage toolResponseMessage) { - - return toolResponseMessage.getResponses() - .stream() - .map(response -> Part.newBuilder() - .setFunctionResponse(FunctionResponse.newBuilder() - .setName(response.name()) - .setResponse(jsonToStruct(response.responseData())) - .build()) - .build()) - .toList(); - } - else { - throw new IllegalArgumentException("Gemini doesn't support message type: " + message.getClass()); - } - } - - private static List mediaToParts(Collection media) { - List parts = new ArrayList<>(); - - List mediaParts = media.stream() - .map(mediaData -> PartMaker.fromMimeTypeAndData(mediaData.getMimeType().toString(), mediaData.getData())) - .toList(); - - if (!CollectionUtils.isEmpty(mediaParts)) { - parts.addAll(mediaParts); - } - - return parts; - } - private List getFunctionTools(Set functionNames) { final var tool = Tool.newBuilder(); @@ -583,37 +573,6 @@ private List getFunctionTools(Set functionNames) { return List.of(tool.build()); } - private static String structToJson(Struct struct) { - try { - return JsonFormat.printer().print(struct); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - private static Struct jsonToStruct(String json) { - try { - var structBuilder = Struct.newBuilder(); - JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder); - return structBuilder.build(); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - private static Schema jsonToSchema(String json) { - try { - var schemaBuilder = Schema.newBuilder(); - JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder); - return schemaBuilder.build(); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - /** * Generates the content response based on the provided Gemini request. Package * protected for testing purposes. @@ -651,4 +610,57 @@ public void setObservationConvention(ChatModelObservationConvention observationC this.observationConvention = observationConvention; } + public enum GeminiMessageType { + + USER("user"), + + MODEL("model"); + + public final String value; + + GeminiMessageType(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + public enum ChatModel implements ChatModelDescription { + + /** + * Deprecated by Goolgle in favor of 1.5 pro and flash models. + */ + GEMINI_PRO_VISION("gemini-pro-vision"), + + GEMINI_PRO("gemini-pro"), + + GEMINI_1_5_PRO("gemini-1.5-pro-001"), + + GEMINI_1_5_FLASH("gemini-1.5-flash-001"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + + @JsonInclude(Include.NON_NULL) + public record GeminiRequest(List contents, GenerativeModel model) { + + } + } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 9088916457b..574574a75eb 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; import java.util.ArrayList; @@ -45,41 +46,43 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions, ChatOp // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig - public enum TransportType { - - GRPC, REST - - } - - // @formatter:off /** * Optional. Stop sequences. */ private @JsonProperty("stopSequences") List stopSequences; + + // @formatter:off + /** * Optional. Controls the randomness of predictions. */ private @JsonProperty("temperature") Double temperature; + /** * Optional. If specified, nucleus sampling will be used. */ private @JsonProperty("topP") Double topP; + /** * Optional. If specified, top k sampling will be used. */ private @JsonProperty("topK") Float topK; + /** * Optional. The maximum number of tokens to generate. */ private @JsonProperty("candidateCount") Integer candidateCount; + /** * Optional. The maximum number of tokens to generate. */ private @JsonProperty("maxOutputTokens") Integer maxOutputTokens; + /** * Gemini model name. */ private @JsonProperty("modelName") String model; + /** * Optional. Output response mimetype of the generated candidate text. * - text/plain: (default) Text output. @@ -123,103 +126,29 @@ public enum TransportType { @JsonIgnore private Map toolContext; - // @formatter:on - public static Builder builder() { return new Builder(); } - public static class Builder { - - private VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withTopK(Float topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withCandidateCount(Integer candidateCount) { - this.options.setCandidateCount(candidateCount); - return this; - } - - public Builder withMaxOutputTokens(Integer maxOutputTokens) { - this.options.setMaxOutputTokens(maxOutputTokens); - return this; - } - - public Builder withModel(String modelName) { - this.options.setModel(modelName); - return this; - } - - public Builder withModel(ChatModel model) { - this.options.setModel(model.getValue()); - return this; - } - - public Builder withResponseMimeType(String mimeType) { - Assert.notNull(mimeType, "mimeType must not be null"); - this.options.setResponseMimeType(mimeType); - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withGoogleSearchRetrieval(boolean googleSearch) { - this.options.googleSearchRetrieval = googleSearch; - return this; - } - - public Builder withProxyToolCalls(boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public VertexAiGeminiChatOptions build() { - return this.options; - } + // @formatter:on + public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fromOptions) { + VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); + options.setStopSequences(fromOptions.getStopSequences()); + options.setTemperature(fromOptions.getTemperature()); + options.setTopP(fromOptions.getTopP()); + options.setTopK(fromOptions.getTopK()); + options.setCandidateCount(fromOptions.getCandidateCount()); + options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); + options.setModel(fromOptions.getModel()); + options.setFunctionCallbacks(fromOptions.getFunctionCallbacks()); + options.setResponseMimeType(fromOptions.getResponseMimeType()); + options.setFunctions(fromOptions.getFunctions()); + options.setResponseMimeType(fromOptions.getResponseMimeType()); + options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); + options.setProxyToolCalls(fromOptions.getProxyToolCalls()); + options.setToolContext(fromOptions.getToolContext()); + return options; } @Override @@ -364,33 +293,39 @@ public void setToolContext(Map toolContext) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof VertexAiGeminiChatOptions that)) + } + if (!(o instanceof VertexAiGeminiChatOptions that)) { return false; - return googleSearchRetrieval == that.googleSearchRetrieval && Objects.equals(stopSequences, that.stopSequences) - && Objects.equals(temperature, that.temperature) && Objects.equals(topP, that.topP) - && Objects.equals(topK, that.topK) && Objects.equals(candidateCount, that.candidateCount) - && Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model) - && Objects.equals(responseMimeType, that.responseMimeType) - && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions) && Objects.equals(proxyToolCalls, that.proxyToolCalls) - && Objects.equals(toolContext, that.toolContext); + } + return this.googleSearchRetrieval == that.googleSearchRetrieval + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount) + && Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model) + && Objects.equals(this.responseMimeType, that.responseMimeType) + && Objects.equals(this.functionCallbacks, that.functionCallbacks) + && Objects.equals(this.functions, that.functions) + && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) + && Objects.equals(this.toolContext, that.toolContext); } @Override public int hashCode() { - return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model, - responseMimeType, functionCallbacks, functions, googleSearchRetrieval, proxyToolCalls, toolContext); + return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, + this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions, + this.googleSearchRetrieval, this.proxyToolCalls, this.toolContext); } @Override public String toString() { - return "VertexAiGeminiChatOptions{" + "stopSequences=" + stopSequences + ", temperature=" + temperature - + ", topP=" + topP + ", topK=" + topK + ", candidateCount=" + candidateCount + ", maxOutputTokens=" - + maxOutputTokens + ", model='" + model + '\'' + ", responseMimeType='" + responseMimeType + '\'' - + ", functionCallbacks=" + functionCallbacks + ", functions=" + functions + ", googleSearchRetrieval=" - + googleSearchRetrieval + '}'; + return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature=" + + this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount=" + + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + + ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks=" + + this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval=" + + this.googleSearchRetrieval + '}'; } @Override @@ -398,23 +333,103 @@ public VertexAiGeminiChatOptions copy() { return fromOptions(this); } - public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fromOptions) { - VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); - options.setStopSequences(fromOptions.getStopSequences()); - options.setTemperature(fromOptions.getTemperature()); - options.setTopP(fromOptions.getTopP()); - options.setTopK(fromOptions.getTopK()); - options.setCandidateCount(fromOptions.getCandidateCount()); - options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); - options.setModel(fromOptions.getModel()); - options.setFunctionCallbacks(fromOptions.getFunctionCallbacks()); - options.setResponseMimeType(fromOptions.getResponseMimeType()); - options.setFunctions(fromOptions.getFunctions()); - options.setResponseMimeType(fromOptions.getResponseMimeType()); - options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); - options.setProxyToolCalls(fromOptions.getProxyToolCalls()); - options.setToolContext(fromOptions.getToolContext()); - return options; + public enum TransportType { + + GRPC, REST + + } + + public static class Builder { + + private VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withTopK(Float topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withCandidateCount(Integer candidateCount) { + this.options.setCandidateCount(candidateCount); + return this; + } + + public Builder withMaxOutputTokens(Integer maxOutputTokens) { + this.options.setMaxOutputTokens(maxOutputTokens); + return this; + } + + public Builder withModel(String modelName) { + this.options.setModel(modelName); + return this; + } + + public Builder withModel(ChatModel model) { + this.options.setModel(model.getValue()); + return this; + } + + public Builder withResponseMimeType(String mimeType) { + Assert.notNull(mimeType, "mimeType must not be null"); + this.options.setResponseMimeType(mimeType); + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withGoogleSearchRetrieval(boolean googleSearch) { + this.options.googleSearchRetrieval = googleSearch; + return this; + } + + public Builder withProxyToolCalls(boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public VertexAiGeminiChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java index 0a46b9f2fa5..fd3d04106c4 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.aot; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java index 4369a2e7992..2d8b69f9861 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java index a250b98e0c1..aeab57df751 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.metadata; import com.google.cloud.vertexai.api.GenerateContentResponse.UsageMetadata; @@ -23,7 +24,7 @@ /** * @author Christian Tzolov * @since 0.8.1 - * + * */ public class VertexAiUsage implements Usage { @@ -36,12 +37,12 @@ public VertexAiUsage(UsageMetadata usageMetadata) { @Override public Long getPromptTokens() { - return Long.valueOf(usageMetadata.getPromptTokenCount()); + return Long.valueOf(this.usageMetadata.getPromptTokenCount()); } @Override public Long getGenerationTokens() { - return Long.valueOf(usageMetadata.getCandidatesTokenCount()); + return Long.valueOf(this.usageMetadata.getCandidatesTokenCount()); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index acbe2672f01..3925fe12506 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,30 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.gemini; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.gemini; import java.net.MalformedURLException; import java.net.URL; import java.util.List; +import com.google.cloud.vertexai.VertexAI; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.Part; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.model.Media; + import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.GeminiRequest; +import org.springframework.ai.vertexai.gemini.function.MockWeatherService; import org.springframework.util.MimeTypeUtils; -import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.Content; -import com.google.cloud.vertexai.api.Part; -import org.springframework.ai.vertexai.gemini.function.MockWeatherService; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -50,7 +51,7 @@ public class CreateGeminiRequestTests { @Test public void createRequestWithChatOptions() { - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build()); GeminiRequest request = client.createGeminiRequest(new Prompt("Test message content"), null); @@ -81,7 +82,7 @@ public void createRequestWithSystemMessage() throws MalformedURLException { var userMessage = new UserMessage("User Message Text", List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("http://example.com")))); - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build()); GeminiRequest request = client.createGeminiRequest(new Prompt(List.of(systemMessage, userMessage)), null); @@ -110,7 +111,7 @@ public void promptOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder().withModel("DEFAULT_MODEL").build()); var request = client.createGeminiRequest(new Prompt("Test message content", @@ -141,7 +142,7 @@ public void defaultOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() .withModel("DEFAULT_MODEL") .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) @@ -198,7 +199,7 @@ public void defaultOptionsTools() { @Test public void createRequestWithGenerationConfigOptions() { - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() .withModel("DEFAULT_MODEL") .withTemperature(66.6) diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java index c7db75aab0a..9ab82aa64b7 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +16,17 @@ package org.springframework.ai.vertexai.gemini; +import java.io.IOException; +import java.util.List; + import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.generativeai.GenerativeModel; + import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.retry.support.RetryTemplate; -import java.io.IOException; -import java.util.List; - /** * @author Mark Pollack */ @@ -41,9 +42,9 @@ public TestVertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions @Override GenerateContentResponse getContentResponse(GeminiRequest request) { - if (mockGenerativeModel != null) { + if (this.mockGenerativeModel != null) { try { - return mockGenerativeModel.generateContent(request.contents()); + return this.mockGenerativeModel.generateContent(request.contents()); } catch (IOException e) { // Should not be thrown by testing class diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java index 093dd3ba994..e34963c5971 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,17 @@ package org.springframework.ai.vertexai.gemini; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.stream.Collectors; +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -38,11 +41,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import com.google.cloud.vertexai.Transport; -import com.google.cloud.vertexai.VertexAI; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Soby Chacko @@ -60,7 +59,7 @@ public class VertexAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -76,7 +75,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -98,7 +97,7 @@ void observationForStreamingOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponse = chatModel.stream(prompt); + Flux chatResponse = this.chatModel.stream(prompt); List responses = chatResponse.collectList().block(); assertThat(responses).isNotEmpty(); assertThat(responses).hasSizeGreaterThan(1); @@ -118,7 +117,7 @@ void observationForStreamingOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java index b11c8f1ec9c..4d6a064448c 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; import java.io.IOException; @@ -28,18 +29,18 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -66,19 +67,19 @@ class VertexAiGeminiChatModelIT { @Test void roleTest() { Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @Test void testMessageHistory() { Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), prompt.getInstructions().get(1), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -86,7 +87,7 @@ void testMessageHistory() { @Test void googleSearchTool() { Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().withGoogleSearchRetrieval(true).build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -96,7 +97,7 @@ private Prompt createPrompt(VertexAiGeminiChatOptions chatOptions) { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), chatOptions); return prompt; @@ -133,16 +134,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -156,7 +154,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConvert.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -166,7 +164,8 @@ void beanOutputConverterRecords() { @Test void textStream() { - String generationTextFromStream = chatModel.stream(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.")) + String generationTextFromStream = this.chatModel + .stream(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.")) .collectList() .block() .stream() @@ -194,7 +193,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -218,7 +217,7 @@ void multiModalityTest() throws IOException { var userMessage = new UserMessage("Explain what do you see o this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, data))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); // Response should contain something like: // I see a bunch of bananas in a golden basket. The bananas are ripe and yellow. @@ -247,6 +246,10 @@ void multiModalityTest() throws IOException { // https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/intro_multimodal_use_cases.ipynb } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java index 387d2550161..bddb9328a28 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,10 +16,9 @@ package org.springframework.ai.vertexai.gemini; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.*; +import java.io.IOException; +import java.util.Collections; +import java.util.List; import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.Candidate; @@ -32,6 +31,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.retry.RetryUtils; @@ -41,11 +41,10 @@ import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.when; /** * @author Mark Pollack @@ -54,25 +53,6 @@ @ExtendWith(MockitoExtension.class) public class VertexAiGeminiRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -87,19 +67,19 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void setUp() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new TestVertexAiGeminiChatModel(vertexAI, + this.chatModel = new TestVertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() .withTemperature(0.7) .withTopP(1.0) .withModel(VertexAiGeminiChatModel.ChatModel.GEMINI_PRO.getValue()) .build(), - null, Collections.emptyList(), retryTemplate); + null, Collections.emptyList(), this.retryTemplate); - chatModel.setMockGenerativeModel(mockGenerativeModel); + this.chatModel.setMockGenerativeModel(this.mockGenerativeModel); } @Test @@ -111,29 +91,48 @@ public void vertexAiGeminiChatTransientError() throws IOException { .build()) .build(); - when(mockGenerativeModel.generateContent(any(List.class))) + when(this.mockGenerativeModel.generateContent(any(List.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(mockedResponse); // Call the chat model - ChatResponse result = chatModel.call(new Prompt("test prompt")); + ChatResponse result = this.chatModel.call(new Prompt("test prompt")); // Assertions assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isEqualTo("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void vertexAiGeminiChatNonTransientError() throws Exception { // Set up the mock GenerativeModel to throw a non-transient RuntimeException - when(mockGenerativeModel.generateContent(any(List.class))) + when(this.mockGenerativeModel.generateContent(any(List.class))) .thenThrow(new RuntimeException("Non Transient Error")); // Assert that a RuntimeException is thrown when calling the chat model - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("test prompt"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("test prompt"))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java index a4aaf39884e..88774e9e356 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.aot; import java.util.Set; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java index ff62411a98f..a7f7521df5a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.function; import java.util.function.Function; @@ -32,14 +33,22 @@ public class MockWeatherService implements Function response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String responseString = response.collectList() .block() @@ -203,16 +203,18 @@ public void functionCallTestInferredOpenApiSchemaStream() { .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", responseString); + this.logger.info("Response: {}", responseString); assertThat(responseString).contains("30", "10", "15"); } public record PaymentInfoRequest(String id) { + } public record TransactionStatus(String status) { + } public static class PaymentStatus implements Function { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 7f6ec1d3ad2..4d674c8bd5b 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,23 +16,25 @@ package org.springframework.ai.vertexai.gemini.function; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; @@ -44,10 +46,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Description; -import com.google.cloud.vertexai.Transport; -import com.google.cloud.vertexai.VertexAI; - -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,51 +58,12 @@ public class VertexAiGeminiPaymentTransactionIT { private final static Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionIT.class); + private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), + new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + @Autowired ChatClient chatClient; - record TransactionStatusResponse(String id, String status) { - } - - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - var response = chain.nextAroundCall(before(advisedRequest)); - observeAfter(response); - return response; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - @Test public void paymentStatuses() { // @formatter:off @@ -149,6 +109,49 @@ public void streamingPaymentStatuses() { } } + record TransactionStatusResponse(String id, String status) { + + } + + private static class LoggingAdvisor implements CallAroundAdvisor { + + private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + var response = chain.nextAroundCall(before(advisedRequest)); + observeAfter(response); + return response; + } + + private AdvisedRequest before(AdvisedRequest request) { + this.logger.info("System text: \n" + request.systemText()); + this.logger.info("System params: " + request.systemParams()); + this.logger.info("User text: \n" + request.userText()); + this.logger.info("User params:" + request.userParams()); + this.logger.info("Function names: " + request.functionNames()); + + this.logger.info("Options: " + request.chatOptions().toString()); + + return request; + } + + private void observeAfter(AdvisedResponse advisedResponse) { + this.logger.info("Response: " + advisedResponse.response()); + } + + } + record Transaction(String id) { } @@ -161,9 +164,6 @@ record Transactions(List transactions) { record Statuses(List statuses) { } - private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), - new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); - @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-vertex-ai-palm2/pom.xml b/models/spring-ai-vertex-ai-palm2/pom.xml index 2c5123b3c2f..07455d51dd4 100644 --- a/models/spring-ai-vertex-ai-palm2/pom.xml +++ b/models/spring-ai-vertex-ai-palm2/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java index de8d732636e..b68e054a47c 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.List; import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageRequest; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt; -import org.springframework.ai.chat.messages.MessageType; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java index 77b968f6936..34a63ccb7b3 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -22,8 +25,6 @@ import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -66,34 +67,13 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private VertexAiPaLm2ChatOptions options = new VertexAiPaLm2ChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withCandidateCount(Integer candidateCount) { - this.options.candidateCount = candidateCount; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTopK(Integer topK) { - this.options.topK = topK; - return this; - } - - public VertexAiPaLm2ChatOptions build() { - return this.options; - } - + public static VertexAiPaLm2ChatOptions fromOptions(VertexAiPaLm2ChatOptions fromOptions) { + return VertexAiPaLm2ChatOptions.builder() + .withTemperature(fromOptions.getTemperature()) + .withCandidateCount(fromOptions.getCandidateCount()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .build(); } @Override @@ -166,13 +146,34 @@ public VertexAiPaLm2ChatOptions copy() { return fromOptions(this); } - public static VertexAiPaLm2ChatOptions fromOptions(VertexAiPaLm2ChatOptions fromOptions) { - return VertexAiPaLm2ChatOptions.builder() - .withTemperature(fromOptions.getTemperature()) - .withCandidateCount(fromOptions.getCandidateCount()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .build(); + public static class Builder { + + private VertexAiPaLm2ChatOptions options = new VertexAiPaLm2ChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withCandidateCount(Integer candidateCount) { + this.options.candidateCount = candidateCount; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.options.topK = topK; + return this; + } + + public VertexAiPaLm2ChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java index 3fcdb935720..3ece4c0bc33 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.List; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java index 65a8952c643..8e9ea670788 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.aot; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java index 2ad7c827bf3..e1e46326dcd 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.api; import java.io.IOException; @@ -207,10 +208,6 @@ record EmbeddingResponse(Embedding embedding) { return response != null ? response.embedding() : null; } - @JsonInclude(Include.NON_NULL) - record BatchEmbeddingResponse(List embeddings) { - } - /** * Generates a response from the model given an input. * @param texts List of texts to embed. @@ -294,6 +291,10 @@ public Model getModel(String modelName) { .body(Model.class); } + @JsonInclude(Include.NON_NULL) + record BatchEmbeddingResponse(List embeddings) { + } + /** * API error response. * @@ -375,12 +376,12 @@ public record Embedding( @Override public final int hashCode() { - return Arrays.hashCode(value); + return Arrays.hashCode(this.value); } @Override public final boolean equals(Object arg0) { - return Arrays.equals(value,((Embedding) arg0).value); + return Arrays.equals(this.value,((Embedding) arg0).value); } } diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java index 3a98497cb10..21d07a608a2 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.Arrays; @@ -22,10 +23,10 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -59,10 +60,10 @@ void roleTest() { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Bartholomew"); } @@ -98,16 +99,13 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - // @Test void beanOutputConverterRecords() { @@ -120,13 +118,17 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java index 9e2d6726cf6..d2dc2802459 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public class VertexAiPaLm2ChatRequestTests { @Test public void createRequestWithDefaultOptions() { - var request = chatModel.createRequest(new Prompt("Test message content")); + var request = this.chatModel.createRequest(new Prompt("Test message content")); assertThat(request.prompt().messages()).hasSize(1); @@ -55,7 +56,7 @@ public void createRequestWithPromptVertexAiOptions() { // .withCandidateCount(2) .build(); - var request = chatModel.createRequest(new Prompt("Test message content", promptOptions)); + var request = this.chatModel.createRequest(new Prompt("Test message content", promptOptions)); assertThat(request.prompt().messages()).hasSize(1); @@ -75,7 +76,7 @@ public void createRequestWithPromptPortableChatOptions() { .withTopP(0.6) .build(); - var request = chatModel.createRequest(new Prompt("Test message content", portablePromptOptions)); + var request = this.chatModel.createRequest(new Prompt("Test message content", portablePromptOptions)); assertThat(request.prompt().messages()).hasSize(1); diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java index 2e05dbfd12d..221964fd12f 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.List; @@ -38,17 +39,17 @@ class VertexAiPaLm2EmbeddingModelIT { @Test void simpleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -56,7 +57,7 @@ void batchEmbedding() { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @SpringBootConfiguration diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java index 5ed21d4649b..ae57c09a3a4 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java index d820d49521a..40f854df0f0 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.api; import java.util.List; @@ -47,7 +48,7 @@ public void generateMessage() { GenerateMessageRequest request = new GenerateMessageRequest(prompt); - GenerateMessageResponse response = vertexAiPaLm2Api.generateMessage(request); + GenerateMessageResponse response = this.vertexAiPaLm2Api.generateMessage(request); assertThat(response).isNotNull(); @@ -66,7 +67,7 @@ public void embedText() { var text = "Hello, how are you?"; - Embedding response = vertexAiPaLm2Api.embedText(text); + Embedding response = this.vertexAiPaLm2Api.embedText(text); assertThat(response).isNotNull(); assertThat(response.value()).hasSize(768); @@ -77,7 +78,7 @@ public void batchEmbedText() { var text = List.of("Hello, how are you?", "I am fine, thank you!"); - List response = vertexAiPaLm2Api.batchEmbedText(text); + List response = this.vertexAiPaLm2Api.batchEmbedText(text); assertThat(response).isNotNull(); assertThat(response).hasSize(2); @@ -91,7 +92,7 @@ public void countMessageTokens() { var text = "Hello, how are you?"; var prompt = new MessagePrompt(List.of(new VertexAiPaLm2Api.Message("0", text))); - int response = vertexAiPaLm2Api.countMessageTokens(prompt); + int response = this.vertexAiPaLm2Api.countMessageTokens(prompt); assertThat(response).isEqualTo(17); } @@ -99,14 +100,14 @@ public void countMessageTokens() { @Test public void listModels() { - List response = vertexAiPaLm2Api.listModels(); + List response = this.vertexAiPaLm2Api.listModels(); assertThat(response).isNotNull(); assertThat(response).hasSizeGreaterThan(0); assertThat(response).contains("models/chat-bison-001", "models/text-bison-001", "models/embedding-gecko-001"); System.out.println(" - " + response.stream() - .map(vertexAiPaLm2Api::getModel) + .map(this.vertexAiPaLm2Api::getModel) .map(VertexAiPaLm2Api.Model::toString) .collect(Collectors.joining("\n - "))); } @@ -114,7 +115,7 @@ public void listModels() { @Test public void getModel() { - VertexAiPaLm2Api.Model model = vertexAiPaLm2Api.getModel("models/chat-bison-001"); + VertexAiPaLm2Api.Model model = this.vertexAiPaLm2Api.getModel("models/chat-bison-001"); System.out.println(model); assertThat(model).isNotNull(); diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java index 2f58d9dd6a3..77bfe39a490 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.api; import java.util.List; @@ -26,8 +27,8 @@ import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.Embedding; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageRequest; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse; -import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse.ContentFilter.BlockedReason; +import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -62,7 +63,7 @@ public class VertexAiPaLm2ApiTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -76,18 +77,19 @@ public void generateMessage() throws JsonProcessingException { List.of(new VertexAiPaLm2Api.Message("0", "I'm fine, thank you.")), List.of(new VertexAiPaLm2Api.GenerateMessageResponse.ContentFilter(BlockedReason.SAFETY, "reason"))); - server + this.server .expect(requestToUriTemplate("/models/{generative}:generateMessage?key={apiKey}", VertexAiPaLm2Api.DEFAULT_GENERATE_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) - .andExpect(content().json(objectMapper.writeValueAsString(request))) - .andRespond(withSuccess(objectMapper.writeValueAsString(expectedResponse), MediaType.APPLICATION_JSON)); + .andExpect(content().json(this.objectMapper.writeValueAsString(request))) + .andRespond( + withSuccess(this.objectMapper.writeValueAsString(expectedResponse), MediaType.APPLICATION_JSON)); - GenerateMessageResponse response = client.generateMessage(request); + GenerateMessageResponse response = this.client.generateMessage(request); assertThat(response).isEqualTo(expectedResponse); - server.verify(); + this.server.verify(); } @Test @@ -97,19 +99,19 @@ public void embedText() throws JsonProcessingException { Embedding expectedEmbedding = new Embedding(new float[] { 0.1f, 0.2f, 0.3f }); - server + this.server .expect(requestToUriTemplate("/models/{generative}:embedText?key={apiKey}", VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) - .andExpect(content().json(objectMapper.writeValueAsString(Map.of("text", text)))) - .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), + .andExpect(content().json(this.objectMapper.writeValueAsString(Map.of("text", text)))) + .andRespond(withSuccess(this.objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), MediaType.APPLICATION_JSON)); - Embedding embedding = client.embedText(text); + Embedding embedding = this.client.embedText(text); assertThat(embedding).isEqualTo(expectedEmbedding); - server.verify(); + this.server.verify(); } @Test @@ -120,19 +122,19 @@ public void batchEmbedText() throws JsonProcessingException { List expectedEmbeddings = List.of(new Embedding(new float[] { 0.1f, 0.2f, 0.3f }), new Embedding(new float[] { 0.4f, 0.5f, 0.6f })); - server + this.server .expect(requestToUriTemplate("/models/{generative}:batchEmbedText?key={apiKey}", VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) - .andExpect(content().json(objectMapper.writeValueAsString(Map.of("texts", texts)))) - .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embeddings", expectedEmbeddings)), + .andExpect(content().json(this.objectMapper.writeValueAsString(Map.of("texts", texts)))) + .andRespond(withSuccess(this.objectMapper.writeValueAsString(Map.of("embeddings", expectedEmbeddings)), MediaType.APPLICATION_JSON)); - List embeddings = client.batchEmbedText(texts); + List embeddings = this.client.batchEmbedText(texts); assertThat(embeddings).isEqualTo(expectedEmbeddings); - server.verify(); + this.server.verify(); } @SpringBootConfiguration diff --git a/models/spring-ai-watsonx-ai/pom.xml b/models/spring-ai-watsonx-ai/pom.xml index e5fc7df00d8..3fc195eb470 100644 --- a/models/spring-ai-watsonx-ai/pom.xml +++ b/models/spring-ai-watsonx-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java index 218ce46976e..8b4e78b34ec 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx; import java.util.List; import java.util.Map; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.model.ChatModel; import reactor.core.publisher.Flux; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; @@ -145,4 +146,4 @@ public ChatOptions getDefaultOptions() { return WatsonxAiChatOptions.fromOptions(this.defaultOptions); } -} \ No newline at end of file +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index e32ba2d1b02..9a113da5795 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx; import java.util.HashMap; @@ -20,13 +21,14 @@ import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonAnyGetter; import com.fasterxml.jackson.annotation.JsonAnySetter; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; + import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -44,6 +46,9 @@ public class WatsonxAiChatOptions implements ChatOptions { + @JsonIgnore + private final ObjectMapper mapper = new ObjectMapper(); + /** * The temperature of the model. Increasing the temperature will * make the model answer more creatively. (Default: 0.7) @@ -122,12 +127,41 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonProperty("additional") private Map additional = new HashMap<>(); - @JsonIgnore - private final ObjectMapper mapper = new ObjectMapper(); + public static Builder builder() { + return new Builder(); + } + + /** + * Filter out the non-supported fields from the options. + * @param options The options to filter. + * @return The filtered options. + */ + public static Map filterNonSupportedFields(Map options) { + return options.entrySet().stream() + .filter(e -> !e.getKey().equals("model")) + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) { + return WatsonxAiChatOptions.builder() + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withDecodingMethod(fromOptions.getDecodingMethod()) + .withMaxNewTokens(fromOptions.getMaxNewTokens()) + .withMinNewTokens(fromOptions.getMinNewTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .withRepetitionPenalty(fromOptions.getRepetitionPenalty()) + .withRandomSeed(fromOptions.getRandomSeed()) + .withModel(fromOptions.getModel()) + .withAdditionalProperties(fromOptions.getAdditionalProperties()) + .build(); + } @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -136,7 +170,7 @@ public void setTemperature(Double temperature) { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -145,7 +179,7 @@ public void setTopP(Double topP) { @Override public Integer getTopK() { - return topK; + return this.topK; } public void setTopK(Integer topK) { @@ -153,7 +187,7 @@ public void setTopK(Integer topK) { } public String getDecodingMethod() { - return decodingMethod; + return this.decodingMethod; } public void setDecodingMethod(String decodingMethod) { @@ -172,7 +206,7 @@ public void setMaxTokens(Integer maxTokens) { } public Integer getMaxNewTokens() { - return maxNewTokens; + return this.maxNewTokens; } public void setMaxNewTokens(Integer maxNewTokens) { @@ -180,7 +214,7 @@ public void setMaxNewTokens(Integer maxNewTokens) { } public Integer getMinNewTokens() { - return minNewTokens; + return this.minNewTokens; } public void setMinNewTokens(Integer minNewTokens) { @@ -189,7 +223,7 @@ public void setMinNewTokens(Integer minNewTokens) { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -208,7 +242,7 @@ public void setPresencePenalty(Double presencePenalty) { } public Double getRepetitionPenalty() { - return repetitionPenalty; + return this.repetitionPenalty; } public void setRepetitionPenalty(Double repetitionPenalty) { @@ -216,7 +250,7 @@ public void setRepetitionPenalty(Double repetitionPenalty) { } public Integer getRandomSeed() { - return randomSeed; + return this.randomSeed; } public void setRandomSeed(Integer randomSeed) { @@ -225,7 +259,7 @@ public void setRandomSeed(Integer randomSeed) { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -234,7 +268,7 @@ public void setModel(String model) { @JsonAnyGetter public Map getAdditionalProperties() { - return additional.entrySet().stream() + return this.additional.entrySet().stream() .collect(Collectors.toMap( entry -> toSnakeCase(entry.getKey()), Map.Entry::getValue @@ -243,7 +277,7 @@ public Map getAdditionalProperties() { @JsonAnySetter public void addAdditionalProperty(String key, Object value) { - additional.put(key, value); + this.additional.put(key, value); } @Override @@ -252,9 +286,31 @@ public Double getFrequencyPenalty() { return null; } - public static Builder builder() { - return new Builder(); - } + /** + * Convert the {@link WatsonxAiChatOptions} object to a {@link Map} of key/value pairs. + * @return The {@link Map} of key/value pairs. + */ + public Map toMap() { + try { + var json = this.mapper.writeValueAsString(this); + var map = this.mapper.readValue(json, new TypeReference>() {}); + map.remove("additional"); + + return map; + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private String toSnakeCase(String input) { + return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; + } + + @Override + public WatsonxAiChatOptions copy() { + return fromOptions(this); + } public static class Builder { @@ -325,59 +381,5 @@ public WatsonxAiChatOptions build() { } } - /** - * Convert the {@link WatsonxAiChatOptions} object to a {@link Map} of key/value pairs. - * @return The {@link Map} of key/value pairs. - */ - public Map toMap() { - try { - var json = mapper.writeValueAsString(this); - var map = mapper.readValue(json, new TypeReference>() {}); - map.remove("additional"); - - return map; - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Filter out the non-supported fields from the options. - * @param options The options to filter. - * @return The filtered options. - */ - public static Map filterNonSupportedFields(Map options) { - return options.entrySet().stream() - .filter(e -> !e.getKey().equals("model")) - .filter(e -> e.getValue() != null) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - - private String toSnakeCase(String input) { - return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; - } - - @Override - public WatsonxAiChatOptions copy() { - return fromOptions(this); - } - - public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) { - return WatsonxAiChatOptions.builder() - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withDecodingMethod(fromOptions.getDecodingMethod()) - .withMaxNewTokens(fromOptions.getMaxNewTokens()) - .withMinNewTokens(fromOptions.getMinNewTokens()) - .withStopSequences(fromOptions.getStopSequences()) - .withRepetitionPenalty(fromOptions.getRepetitionPenalty()) - .withRandomSeed(fromOptions.getRandomSeed()) - .withModel(fromOptions.getModel()) - .withAdditionalProperties(fromOptions.getAdditionalProperties()) - .build(); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java index 5b3e03ea1dc..18e3ae3617a 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java @@ -1,18 +1,40 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.watsonx.api.WatsonxAiApi; import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest; import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - /** * {@link EmbeddingModel} implementation for {@literal Watsonx.ai}. *

diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java index 9db6b6dd517..ab1622feaf1 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java @@ -1,8 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -22,13 +39,25 @@ public class WatsonxAiEmbeddingOptions implements EmbeddingOptions { @JsonProperty("model_id") private String model; + /** + * Helper factory method to create a new {@link WatsonxAiEmbeddingOptions} instance. + * @return A new {@link WatsonxAiEmbeddingOptions} instance. + */ + public static WatsonxAiEmbeddingOptions create() { + return new WatsonxAiEmbeddingOptions(); + } + + public static WatsonxAiEmbeddingOptions fromOptions(WatsonxAiEmbeddingOptions fromOptions) { + return new WatsonxAiEmbeddingOptions().withModel(fromOptions.getModel()); + } + public WatsonxAiEmbeddingOptions withModel(String model) { this.model = model; return this; } public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -41,16 +70,4 @@ public Integer getDimensions() { return null; } - /** - * Helper factory method to create a new {@link WatsonxAiEmbeddingOptions} instance. - * @return A new {@link WatsonxAiEmbeddingOptions} instance. - */ - public static WatsonxAiEmbeddingOptions create() { - return new WatsonxAiEmbeddingOptions(); - } - - public static WatsonxAiEmbeddingOptions fromOptions(WatsonxAiEmbeddingOptions fromOptions) { - return new WatsonxAiEmbeddingOptions().withModel(fromOptions.getModel()); - } - } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java index c76470a7aca..b78266e7aac 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.aot; import org.springframework.ai.watsonx.WatsonxAiChatOptions; diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java index 2de2f36fd06..7953f8c5608 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; import java.util.List; @@ -23,14 +24,14 @@ import com.ibm.cloud.sdk.core.security.IamToken; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.retry.annotation.Backoff; -import org.springframework.retry.annotation.Retryable; import reactor.core.publisher.Flux; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.retry.annotation.Backoff; +import org.springframework.retry.annotation.Retryable; import org.springframework.util.Assert; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -45,8 +46,10 @@ // @formatter:off public class WatsonxAiApi { - private static final Log logger = LogFactory.getLog(WatsonxAiApi.class); public static final String WATSONX_REQUEST_CANNOT_BE_NULL = "Watsonx Request cannot be null"; + + private static final Log logger = LogFactory.getLog(WatsonxAiApi.class); + private final RestClient restClient; private final WebClient webClient; private final IamAuthenticator iamAuthenticator; @@ -108,7 +111,7 @@ public ResponseEntity generate(WatsonxAiChatRequest watso return this.restClient.post() .uri(this.textEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .body(watsonxAiChatRequest.withProjectId(projectId)) + .body(watsonxAiChatRequest.withProjectId(this.projectId)) .retrieve() .toEntity(WatsonxAiChatResponse.class); } @@ -146,7 +149,7 @@ public ResponseEntity embeddings(WatsonxAiEmbeddingR return this.restClient.post() .uri(this.embeddingEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .body(request.withProjectId(projectId)) + .body(request.withProjectId(this.projectId)) .retrieve() .toEntity(WatsonxAiEmbeddingResponse.class); } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java index c228372cbcd..817e9802f2a 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; import java.util.Map; @@ -49,19 +50,18 @@ private WatsonxAiChatRequest(String input, Map parameters, Strin this.projectId = projectId; } + public static Builder builder(String input) { return new Builder(input); } + public WatsonxAiChatRequest withProjectId(String projectId) { this.projectId = projectId; return this; } - public String getInput() { return input; } + public String getInput() { return this.input; } - public Map getParameters() { return parameters; } + public Map getParameters() { return this.parameters; } - public String getModelId() { return modelId; } - - - public static Builder builder(String input) { return new Builder(input); } + public String getModelId() { return this.modelId; } public static class Builder { public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required"; @@ -81,7 +81,7 @@ public Builder withParameters(Map parameters) { } public WatsonxAiChatRequest build() { - return new WatsonxAiChatRequest(input, parameters, model, ""); + return new WatsonxAiChatRequest(this.input, this.parameters, this.model, ""); } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java index 36127771b35..f90ce643645 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.watsonx.api; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; +package org.springframework.ai.watsonx.api; import java.util.Date; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + /** * Java class for Watsonx.ai Chat Response object. * diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java index 316f8c2e478..ecb67f8d937 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java index 331dfa0a1af..8e8da278dff 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java @@ -1,10 +1,27 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx.api; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; -import java.util.List; +import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; /** * Java class for Watsonx.ai Embedding Request object. @@ -24,35 +41,35 @@ public class WatsonxAiEmbeddingRequest { @JsonProperty("project_id") String projectId; - public String getModel() { - return model; - } - - public List getInputs() { - return inputs; - } - private WatsonxAiEmbeddingRequest(String model, List inputs, String projectId) { this.model = model; this.inputs = inputs; this.projectId = projectId; } + public static Builder builder(List inputs) { + return new Builder(inputs); + } + + public String getModel() { + return this.model; + } + + public List getInputs() { + return this.inputs; + } + public WatsonxAiEmbeddingRequest withProjectId(String projectId) { this.projectId = projectId; return this; } - public static Builder builder(List inputs) { - return new Builder(inputs); - } - public static class Builder { - private String model = WatsonxAiEmbeddingOptions.DEFAULT_MODEL; - private final List inputs; + private String model = WatsonxAiEmbeddingOptions.DEFAULT_MODEL; + public Builder(List inputs) { this.inputs = inputs; } @@ -63,7 +80,7 @@ public Builder withModel(String model) { } public WatsonxAiEmbeddingRequest build() { - return new WatsonxAiEmbeddingRequest(model, inputs, ""); + return new WatsonxAiEmbeddingRequest(this.model, this.inputs, ""); } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java index ec1ae022605..a2284afeeed 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java @@ -1,11 +1,27 @@ -package org.springframework.ai.watsonx.api; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; +package org.springframework.ai.watsonx.api; import java.util.Date; import java.util.List; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + /** * Java class for Watsonx.ai Embedding Response object. * @@ -16,4 +32,5 @@ public record WatsonxAiEmbeddingResponse(@JsonProperty("model_id") String model, @JsonProperty("created_at") Date createdAt, @JsonProperty("results") List results, @JsonProperty("input_token_count") Integer inputTokenCount) { + } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java index 975a1195e9e..a86dd12a242 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java @@ -1,10 +1,24 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx.api; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.List; - /** * Java class for Watsonx.ai Embedding Results object. * @@ -13,4 +27,5 @@ */ @JsonInclude(JsonInclude.Include.NON_NULL) public record WatsonxAiEmbeddingResults(@JsonProperty("embedding") float[] embedding) { + } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java index 75be3e17378..449ec8f7349 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.watsonx.utils; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; +package org.springframework.ai.watsonx.utils; import java.util.List; import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; + // @formatter:off public class MessageToPromptConverter { + public static final String TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS = "Tool execution results are not supported for watsonx.ai models"; + private static final String HUMAN_PROMPT = "Human: "; + private static final String ASSISTANT_PROMPT = "Assistant: "; - public static final String TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS = "Tool execution results are not supported for watsonx.ai models"; + private String humanPrompt = HUMAN_PROMPT; private String assistantPrompt = ASSISTANT_PROMPT; @@ -60,7 +64,7 @@ public String toPrompt(List messages) { .map(this::messageToString) .collect(Collectors.joining("\n")); - return String.format("%s%n%n%s%n%s", systemMessages, userMessages, assistantPrompt).trim(); + return String.format("%s%n%n%s%n%s", systemMessages, userMessages, this.assistantPrompt).trim(); } protected String messageToString(Message message) { @@ -68,9 +72,9 @@ protected String messageToString(Message message) { case SYSTEM: return message.getContent(); case USER: - return humanPrompt + message.getContent(); + return this.humanPrompt + message.getContent(); case ASSISTANT: - return assistantPrompt + message.getContent(); + return this.assistantPrompt + message.getContent(); case TOOL: throw new IllegalArgumentException(TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS); } diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java index 313e0bc2547..4a41f72f706 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx; import java.util.Date; @@ -25,10 +26,10 @@ import reactor.core.publisher.Flux; import reactor.test.StepVerifier; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.watsonx.api.WatsonxAiApi; @@ -57,7 +58,7 @@ public void testCreateRequestWithNoModelId() { Prompt prompt = new Prompt("Test message", options); Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> { - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); }); } @@ -71,7 +72,7 @@ public void testCreateRequestSuccessfullyWithDefaultParams() { .build(); Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getParameters().get("decoding_method")).isEqualTo("greedy"); @@ -105,7 +106,7 @@ public void testCreateRequestSuccessfullyWithNonDefaultParams() { Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getParameters().get("decoding_method")).isEqualTo("sample"); @@ -139,7 +140,7 @@ public void testCreateRequestSuccessfullyWithChatDisabled() { Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getInput()).isEqualTo(msg); diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java index 42e6c0cc5bf..4e19920ec90 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java @@ -1,6 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx; +import java.util.Date; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.watsonx.api.WatsonxAiApi; @@ -9,9 +29,6 @@ import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResults; import org.springframework.http.ResponseEntity; -import java.util.Date; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -19,13 +36,13 @@ public class WatsonxAiEmbeddingModelTest { - private WatsonxAiApi watsonxAiApiMock; - private final WatsonxAiEmbeddingModel embeddingModel; + private WatsonxAiApi watsonxAiApiMock; + public WatsonxAiEmbeddingModelTest() { this.watsonxAiApiMock = mock(WatsonxAiApi.class); - this.embeddingModel = new WatsonxAiEmbeddingModel(watsonxAiApiMock); + this.embeddingModel = new WatsonxAiEmbeddingModel(this.watsonxAiApiMock); } @Test @@ -34,7 +51,7 @@ void createRequestWithOptions() { List inputs = List.of("test"); WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create().withModel(MODEL); - WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, options); + WatsonxAiEmbeddingRequest request = this.embeddingModel.watsonxAiEmbeddingRequest(inputs, options); assertThat(request.getModel()).isEqualTo(MODEL); assertThat(request.getInputs().size()).isEqualTo(inputs.size()); @@ -46,7 +63,7 @@ void createRequestWithOptionsAndInvalidModel() { List inputs = List.of("test"); WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create().withModel(MODEL); - WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, options); + WatsonxAiEmbeddingRequest request = this.embeddingModel.watsonxAiEmbeddingRequest(inputs, options); assertThat(request.getModel()).isEqualTo(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); assertThat(request.getInputs().size()).isEqualTo(inputs.size()); @@ -55,7 +72,8 @@ void createRequestWithOptionsAndInvalidModel() { @Test void createRequestWithNoOptions() { List inputs = List.of("test"); - WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, EmbeddingOptions.EMPTY); + WatsonxAiEmbeddingRequest request = this.embeddingModel.watsonxAiEmbeddingRequest(inputs, + EmbeddingOptions.EMPTY); assertThat(request.getModel()).isEqualTo(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); assertThat(request.getInputs().size()).isEqualTo(inputs.size()); @@ -73,14 +91,14 @@ void singleEmbeddingWithOptions() { inputTokenCount); ResponseEntity mockResponseEntity = ResponseEntity.ok(mockResponse); - when(watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).thenReturn(mockResponseEntity); + when(this.watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).thenReturn(mockResponseEntity); - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(2); + assertThat(this.embeddingModel.dimensions()).isEqualTo(2); } } diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java index fddaba9e55f..7d82cd7a4e3 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.watsonx.WatsonxAiChatOptions; @@ -22,8 +25,6 @@ import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java index f77812852f3..ac71fe43e8d 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; -import static org.assertj.core.api.Assertions.assertThat; +import java.util.List; +import java.util.Map; import org.junit.Test; import org.springframework.ai.watsonx.WatsonxAiChatOptions; -import java.util.List; -import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Pablo Sanchidrian Herrera diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java index 98d63092f7b..f5de7587476 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java @@ -1,10 +1,27 @@ -package org.springframework.ai.watsonx.api; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.watsonx.api; import org.junit.Test; + import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Pablo Sanchidrian Herrera */ diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java index 5d22477c612..4005413127e 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.utils; +import java.util.List; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.jupiter.api.Disabled; -import org.springframework.ai.chat.messages.Message; + import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; -import java.util.List; - /** * @author Pablo Sanchidrian Herrera * @author John Jairo Moreno Rojas @@ -36,64 +37,64 @@ public class MessageToPromptConverterTest { @Before public void setUp() { - converter = MessageToPromptConverter.create().withHumanPrompt("").withAssistantPrompt(""); + this.converter = MessageToPromptConverter.create().withHumanPrompt("").withAssistantPrompt(""); } @Test public void testSingleUserMessage() { Message userMessage = new UserMessage("User message"); String expected = "User message"; - Assert.assertEquals(expected, converter.messageToString(userMessage)); + Assert.assertEquals(expected, this.converter.messageToString(userMessage)); } @Test public void testSingleAssistantMessage() { Message assistantMessage = new AssistantMessage("Assistant message"); String expected = "Assistant message"; - Assert.assertEquals(expected, converter.messageToString(assistantMessage)); + Assert.assertEquals(expected, this.converter.messageToString(assistantMessage)); } @Test public void testSystemMessageType() { Message systemMessage = new SystemMessage("System message"); String expected = "System message"; - Assert.assertEquals(expected, converter.messageToString(systemMessage)); + Assert.assertEquals(expected, this.converter.messageToString(systemMessage)); } @Test public void testCustomHumanPrompt() { - converter.withHumanPrompt("Custom Human: "); + this.converter.withHumanPrompt("Custom Human: "); Message userMessage = new UserMessage("User message"); String expected = "Custom Human: User message"; - Assert.assertEquals(expected, converter.messageToString(userMessage)); + Assert.assertEquals(expected, this.converter.messageToString(userMessage)); } @Test public void testCustomAssistantPrompt() { - converter.withAssistantPrompt("Custom Assistant: "); + this.converter.withAssistantPrompt("Custom Assistant: "); Message assistantMessage = new AssistantMessage("Assistant message"); String expected = "Custom Assistant: Assistant message"; - Assert.assertEquals(expected, converter.messageToString(assistantMessage)); + Assert.assertEquals(expected, this.converter.messageToString(assistantMessage)); } @Test public void testEmptyMessageList() { String expected = ""; - Assert.assertEquals(expected, converter.toPrompt(List.of())); + Assert.assertEquals(expected, this.converter.toPrompt(List.of())); } @Test public void testSystemMessageList() { String msg = "this is a LLM prompt"; SystemMessage message = new SystemMessage(msg); - Assert.assertEquals(msg, converter.toPrompt(List.of(message))); + Assert.assertEquals(msg, this.converter.toPrompt(List.of(message))); } @Test public void testUserMessageList() { List messages = List.of(new UserMessage("User message")); String expected = "User message"; - Assert.assertEquals(expected, converter.toPrompt(messages)); + Assert.assertEquals(expected, this.converter.toPrompt(messages)); } } diff --git a/models/spring-ai-zhipuai/pom.xml b/models/spring-ai-zhipuai/pom.xml index 4c2c6e1791c..59df1857b1f 100644 --- a/models/spring-ai-zhipuai/pom.xml +++ b/models/spring-ai-zhipuai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 9a13668d6ec..b6b8e06590d 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -64,16 +76,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal ZhiPuAI} @@ -92,14 +94,14 @@ public class ZhiPuAiChatModel extends AbstractToolCallSupport implements ChatMod private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); /** - * The default options used for the chat completion requests. + * The retry template used to retry the ZhiPuAI API calls. */ - private final ZhiPuAiChatOptions defaultOptions; + public final RetryTemplate retryTemplate; /** - * The retry template used to retry the ZhiPuAI API calls. + * The default options used for the chat completion requests. */ - public final RetryTemplate retryTemplate; + private final ZhiPuAiChatOptions defaultOptions; /** * Low-level access to the ZhiPuAI API. @@ -176,6 +178,21 @@ public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options, this.observationRegistry = observationRegistry; } + private static Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); + + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } + @Override public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); @@ -318,21 +335,6 @@ private ChatResponseMetadata from(ChatCompletion result) { .build(); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() - .stream() - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", - toolCall.function().name(), toolCall.function().arguments())) - .toList(); - - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); - return new Generation(assistantMessage, generationMetadata); - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index e30ab9666d5..c0c66253a41 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -27,12 +35,6 @@ import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - /** * ZhiPuAiChatOptions represents the options for the ZhiPuAiChat model. * @@ -137,104 +139,23 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - protected ZhiPuAiChatOptions options; - - public Builder() { - this.options = new ZhiPuAiChatOptions(); - } - - public Builder(ZhiPuAiChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTools(List tools) { - this.options.tools = tools; - return this; - } - - public Builder withToolChoice(String toolChoice) { - this.options.toolChoice = toolChoice; - return this; - } - - public Builder withUser(String user) { - this.options.user = user; - return this; - } - - public Builder withRequestId(String requestId) { - this.options.requestId = requestId; - return this; - } - - public Builder withDoSample(Boolean doSample) { - this.options.doSample = doSample; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public ZhiPuAiChatOptions build() { - return this.options; - } - + public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { + return ZhiPuAiChatOptions.builder() + .withModel(fromOptions.getModel()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withUser(fromOptions.getUser()) + .withRequestId(fromOptions.getRequestId()) + .withDoSample(fromOptions.getDoSample()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override @@ -317,7 +238,7 @@ public void setUser(String user) { } public String getRequestId() { - return requestId; + return this.requestId; } public void setRequestId(String requestId) { @@ -325,7 +246,7 @@ public void setRequestId(String requestId) { } public Boolean getDoSample() { - return doSample; + return this.doSample; } public void setDoSample(Boolean doSample) { @@ -344,7 +265,7 @@ public void setFunctionCallbacks(List functionCallbacks) { @Override public Set getFunctions() { - return functions; + return this.functions; } public void setFunctions(Set functionNames) { @@ -392,100 +313,127 @@ public void setToolContext(Map toolContext) { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((user == null) ? 0 : user.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); + result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); + result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.maxTokens == null) { - if (other.maxTokens != null) + if (other.maxTokens != null) { return false; + } } - else if (!this.maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) { return false; + } if (this.stop == null) { - if (other.stop != null) + if (other.stop != null) { return false; + } } - else if (!stop.equals(other.stop)) + else if (!this.stop.equals(other.stop)) { return false; + } if (this.temperature == null) { - if (other.temperature != null) + if (other.temperature != null) { return false; + } } - else if (!this.temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) { return false; + } if (this.topP == null) { - if (other.topP != null) + if (other.topP != null) { return false; + } } - else if (!topP.equals(other.topP)) + else if (!this.topP.equals(other.topP)) { return false; + } if (this.tools == null) { - if (other.tools != null) + if (other.tools != null) { return false; + } } - else if (!tools.equals(other.tools)) + else if (!this.tools.equals(other.tools)) { return false; + } if (this.toolChoice == null) { - if (other.toolChoice != null) + if (other.toolChoice != null) { return false; + } } - else if (!toolChoice.equals(other.toolChoice)) + else if (!this.toolChoice.equals(other.toolChoice)) { return false; + } if (this.user == null) { - if (other.user != null) + if (other.user != null) { return false; + } } - else if (!this.user.equals(other.user)) + else if (!this.user.equals(other.user)) { return false; + } if (this.requestId == null) { - if (other.requestId != null) + if (other.requestId != null) { return false; + } } - else if (!this.requestId.equals(other.requestId)) + else if (!this.requestId.equals(other.requestId)) { return false; + } if (this.doSample == null) { - if (other.doSample != null) + if (other.doSample != null) { return false; + } } - else if (!this.doSample.equals(other.doSample)) + else if (!this.doSample.equals(other.doSample)) { return false; + } if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) + if (other.proxyToolCalls != null) { return false; + } } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { return false; + } if (this.toolContext == null) { - if (other.toolContext != null) + if (other.toolContext != null) { return false; + } } - else if (!this.toolContext.equals(other.toolContext)) + else if (!this.toolContext.equals(other.toolContext)) { return false; + } return true; } @@ -494,23 +442,104 @@ public ZhiPuAiChatOptions copy() { return fromOptions(this); } - public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { - return ZhiPuAiChatOptions.builder() - .withModel(fromOptions.getModel()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withUser(fromOptions.getUser()) - .withRequestId(fromOptions.getRequestId()) - .withDoSample(fromOptions.getDoSample()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); + public static class Builder { + + protected ZhiPuAiChatOptions options; + + public Builder() { + this.options = new ZhiPuAiChatOptions(); + } + + public Builder(ZhiPuAiChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withRequestId(String requestId) { + this.options.requestId = requestId; + return this; + } + + public Builder withDoSample(Boolean doSample) { + this.options.doSample = doSample; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public ZhiPuAiChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index d330354208b..214679c32ec 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -39,10 +45,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - /** * ZhiPuAI Embedding Model implementation. * diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java index cbd75ad4e82..02119d53c3b 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -42,6 +44,21 @@ public static Builder builder() { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected ZhiPuAiEmbeddingOptions options; @@ -61,19 +78,4 @@ public ZhiPuAiEmbeddingOptions build() { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java index f7464cb79a2..cb267fd2fd7 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageModel; @@ -30,8 +34,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * ZhiPuAiImageModel is a class that implements the ImageModel interface. It provides a * client for calling the ZhiPuAI image generation API. @@ -43,12 +45,12 @@ public class ZhiPuAiImageModel implements ImageModel { private final static Logger logger = LoggerFactory.getLogger(ZhiPuAiImageModel.class); + public final RetryTemplate retryTemplate; + private final ZhiPuAiImageOptions defaultOptions; private final ZhiPuAiImageApi zhiPuAiImageApi; - public final RetryTemplate retryTemplate; - public ZhiPuAiImageModel(ZhiPuAiImageApi zhiPuAiImageApi) { this(zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java index a6d1de3167e..baa1e8475f7 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.image.ImageOptions; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; -import java.util.Objects; - /** * ZhiPuAiImageOptions represents the options for image generation using ZhiPuAI image * model. @@ -64,30 +66,6 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private final ZhiPuAiImageOptions options; - - private Builder() { - this.options = new ZhiPuAiImageOptions(); - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public ZhiPuAiImageOptions build() { - return options; - } - - } - @Override @JsonIgnore public Integer getN() { @@ -137,21 +115,47 @@ public void setUser(String user) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ZhiPuAiImageOptions that)) + } + if (!(o instanceof ZhiPuAiImageOptions that)) { return false; - return Objects.equals(model, that.model) && Objects.equals(user, that.user); + } + return Objects.equals(this.model, that.model) && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(model, user); + return Objects.hash(this.model, this.user); } @Override public String toString() { - return "ZhiPuAiImageOptions{model='" + model + '\'' + ", user='" + user + '\'' + '}'; + return "ZhiPuAiImageOptions{model='" + this.model + '\'' + ", user='" + this.user + '\'' + '}'; + } + + public static class Builder { + + private final ZhiPuAiImageOptions options; + + private Builder() { + this.options = new ZhiPuAiImageOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public ZhiPuAiImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java index 51185977a56..d5cc2f21e55 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.aot; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index 0758d2e8f89..2be99709d4f 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; import java.util.Arrays; @@ -23,6 +24,12 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; @@ -37,13 +44,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - // @formatter:off /** * Single class implementation of the ZhiPuAI Chat Completion API and @@ -63,6 +63,8 @@ public class ZhiPuAiApi { private final WebClient webClient; + private final ZhiPuAiStreamFunctionCallingHelper chunkMerger = new ZhiPuAiStreamFunctionCallingHelper(); + /** * Create a new chat completion api with default base URL. * @@ -120,6 +122,111 @@ public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restCl .build(); } + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v4/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * + * @param chatRequest The chat completion request. Must have the stream property set to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v4/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null), + this.chunkMerger::merge); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + /** + * Creates an embedding vector representing the input text or token array. + * + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or {@link List} of tokens (e.g. + * Integers). For embedding multiple inputs in a single request, You can pass a {@link List} of {@link String} or + * {@link List} of {@link List} of tokens. For example: + * + *

{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 512, "The list must be 512 dimensions or less"); + Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer + || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri("/v4/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + /** * ZhiPuAI Chat Completion Models: * ZhiPuAI Model. @@ -139,7 +246,7 @@ public enum ChatModel implements ChatModelDescription { } public String getValue() { - return value; + return this.value; } @Override @@ -148,6 +255,58 @@ public String getName() { } } + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") TOOL_CALL + } + + /** + * ZhiPuAI Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 1024 + */ + Embedding_2("Embedding-2"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + } + /** * Represents a tool the model may call. Currently, only functions are supported as a tool. * @@ -355,6 +514,15 @@ public record ChatCompletionMessage( @JsonProperty("tool_call_id") String toolCallId, @JsonProperty("tool_calls") List toolCalls) { + /** + * Create a chat completion message with the given content and role. All other fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + /** * Get message content as String. */ @@ -368,15 +536,6 @@ public String content() { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); - } - /** * The role of the author of this message. */ @@ -415,22 +574,6 @@ public record MediaContent( @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl) { - /** - * @param url Either a URL of the image or the base64 encoded image data. - * The base64 encoded image data must have a special prefix in the following format: - * "data:{mimetype};base64,{base64-encoded-image-data}". - * @param detail Specifies the detail level of the image. - */ - @JsonInclude(Include.NON_NULL) - public record ImageUrl( - @JsonProperty("url") String url, - @JsonProperty("detail") String detail) { - - public ImageUrl(String url) { - this(url, null); - } - } - /** * Shortcut constructor for a text content. * @param text The text content of the message. @@ -446,6 +589,22 @@ public MediaContent(String text) { public MediaContent(ImageUrl imageUrl) { this("image_url", null, imageUrl); } + + /** + * @param url Either a URL of the image or the base64 encoded image data. + * The base64 encoded image data must have a special prefix in the following format: + * "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl( + @JsonProperty("url") String url, + @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + } } /** * The relevant tool call. @@ -475,43 +634,6 @@ public record ChatCompletionFunction( } } - public static String getTextContent(List content) { - return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") TOOL_CALL - } - /** * Represents a chat completion response returned by model, based on the provided input. * @@ -655,91 +777,6 @@ public record ChunkChoice( } } - /** - * Creates a model response for the given chat conversation. - * - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v4/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private final ZhiPuAiStreamFunctionCallingHelper chunkMerger = new ZhiPuAiStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * - * @param chatRequest The chat completion request. Must have the stream property set to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v4/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null), - this.chunkMerger::merge); - return List.of(monoChunk); - }) - .flatMap(mono -> mono); - } - - /** - * ZhiPuAI Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * DIMENSION: 1024 - */ - Embedding_2("Embedding-2"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - /** * Represents an embedding vector returned by embedding endpoint. * @@ -765,20 +802,20 @@ public Embedding(Integer index, float[] embedding) { @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof Embedding embedding1)) return false; - return Objects.equals(index, embedding1.index) && Arrays.equals(embedding, embedding1.embedding) && Objects.equals(object, embedding1.object); + return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) && Objects.equals(this.object, embedding1.object); } @Override public int hashCode() { - int result = Objects.hash(index, object); - result = 31 * result + Arrays.hashCode(embedding); + int result = Objects.hash(this.index, this.object); + result = 31 * result + Arrays.hashCode(this.embedding); return result; } @Override public String toString() { return "Embedding{" + - "index=" + index + - ", embedding=" + Arrays.toString(embedding) + - ", object='" + object + '\'' + + "index=" + this.index + + ", embedding=" + Arrays.toString(this.embedding) + + ", object='" + this.object + '\'' + '}'; } } @@ -821,42 +858,5 @@ public record EmbeddingList( @JsonProperty("usage") Usage usage) { } - /** - * Creates an embedding vector representing the input text or token array. - * - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or {@link List} of tokens (e.g. - * Integers). For embedding multiple inputs in a single request, You can pass a {@link List} of {@link String} or - * {@link List} of {@link List} of tokens. For example: - * - *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
- */ - public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.input(), "The input can not be null."); - Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, - "The input must be either a String, or a List of Strings or List of List of integers."); - - if (embeddingRequest.input() instanceof List list) { - Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); - Assert.isTrue(list.size() <= 512, "The list must be 512 dimensions or less"); - Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer - || list.get(0) instanceof List, - "The input must be either a String, or a List of Strings or list of list of integers."); - } - - return this.restClient.post() - .uri("/v4/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } // @formatter:on diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java index 304ec3146c2..23bfd8404d9 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; import java.util.List; -import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.retry.RetryUtils; -import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * ZhiPuAI Image API. * @@ -76,6 +75,17 @@ public ZhiPuAiImageApi(String baseUrl, String zhiPuAiToken, RestClient.Builder r }).defaultStatusHandler(responseErrorHandler).build(); } + public ResponseEntity createImage(ZhiPuAiImageRequest zhiPuAiImageRequest) { + Assert.notNull(zhiPuAiImageRequest, "Image request cannot be null."); + Assert.hasLength(zhiPuAiImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("/v4/images/generations") + .body(zhiPuAiImageRequest) + .retrieve() + .toEntity(ZhiPuAiImageResponse.class); + } + /** * ZhiPuAI Image API model. * CogView @@ -113,22 +123,11 @@ public record ZhiPuAiImageResponse( @JsonProperty("created") Long created, @JsonProperty("data") List data) { } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("url") String url) { - } // @formatter:onn - public ResponseEntity createImage(ZhiPuAiImageRequest zhiPuAiImageRequest) { - Assert.notNull(zhiPuAiImageRequest, "Image request cannot be null."); - Assert.hasLength(zhiPuAiImageRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url) { - return this.restClient.post() - .uri("/v4/images/generations") - .body(zhiPuAiImageRequest) - .retrieve() - .toEntity(ZhiPuAiImageResponse.class); } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java index b74303729c4..e4629e94b49 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion.Choice; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionChunk; @@ -27,9 +31,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.LogProbs; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.List; - /** * Helper class to support Streaming function calling. It can merge the streamed * ChatCompletionChunk in case of function calling message. diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java index 36d0c4292c1..52f2427712e 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.zhipuai.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java index dc47c1c7f0e..88d197e9f48 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -27,10 +28,6 @@ */ public class ZhiPuAiUsage implements Usage { - public static ZhiPuAiUsage from(ZhiPuAiApi.Usage usage) { - return new ZhiPuAiUsage(usage); - } - private final ZhiPuAiApi.Usage usage; protected ZhiPuAiUsage(ZhiPuAiApi.Usage usage) { @@ -38,6 +35,10 @@ protected ZhiPuAiUsage(ZhiPuAiApi.Usage usage) { this.usage = usage; } + public static ZhiPuAiUsage from(ZhiPuAiApi.Usage usage) { + return new ZhiPuAiUsage(usage); + } + protected ZhiPuAiApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java index eb12b04c500..90dac9f579f 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.zhipuai.api.MockWeatherService; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java index d35ac839ab6..00a760cb1a2 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java index 0d68d135f8f..c1487282b15 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Geng Rong */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } /** @@ -65,28 +71,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } -} \ No newline at end of file +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java index 1837f9ba525..c40f7297653 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.List; +import java.util.Objects; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionChunk; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage; @@ -25,10 +31,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.Embedding; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.EmbeddingList; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; import static org.assertj.core.api.Assertions.assertThat; @@ -43,7 +45,7 @@ public class ZhiPuAiApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = zhiPuAiApi + ResponseEntity response = this.zhiPuAiApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7, false)); assertThat(response).isNotNull(); @@ -53,7 +55,7 @@ void chatCompletionEntity() { @Test void chatCompletionEntityWithMoreParams() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = zhiPuAiApi + ResponseEntity response = this.zhiPuAiApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 1024, null, false, 0.95, 0.7, null, null, null, "test_request_id", false)); @@ -64,7 +66,7 @@ void chatCompletionEntityWithMoreParams() { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = zhiPuAiApi + Flux response = this.zhiPuAiApi .chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7, true)); assertThat(response).isNotNull(); @@ -73,7 +75,7 @@ void chatCompletionStream() { @Test void embeddings() { - ResponseEntity> response = zhiPuAiApi + ResponseEntity> response = this.zhiPuAiApi .embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world")); assertThat(response).isNotNull(); diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java index cf249ae1c43..2c6de05af21 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,17 @@ package org.springframework.ai.zhipuai.api; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage; @@ -32,10 +37,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatModel.GLM_4; @@ -51,6 +52,15 @@ public class ZhiPuAiApiToolFunctionCallIT { ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -92,7 +102,7 @@ public void toolFunctionCall() { ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, GLM_4.value, List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = zhiPuAiApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.zhiPuAiApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -111,7 +121,7 @@ public void toolFunctionCall() { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -122,9 +132,9 @@ public void toolFunctionCall() { var functionResponseRequest = new ChatCompletionRequest(messages, GLM_4.value, List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion2 = zhiPuAiApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.zhiPuAiApi.chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); @@ -133,13 +143,4 @@ public void toolFunctionCall() { .containsAnyOf("30.0°C", "30°C", "30.0°F", "30°F"); } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java index aa76c1be887..af2d147505a 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; @@ -49,10 +55,6 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -66,25 +68,6 @@ @ExtendWith(MockitoExtension.class) public class ZhiPuAiRetryTests { - private class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -101,14 +84,16 @@ public void onError(RetryContext context, RetryCallback @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); - - chatModel = new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), null, retryTemplate); - embeddingModel = new ZhiPuAiEmbeddingModel(zhiPuAiApi, MetadataMode.EMBED, - ZhiPuAiEmbeddingOptions.builder().build(), retryTemplate); - imageModel = new ZhiPuAiImageModel(zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), retryTemplate); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); + + this.chatModel = new ZhiPuAiChatModel(this.zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), null, + this.retryTemplate); + this.embeddingModel = new ZhiPuAiEmbeddingModel(this.zhiPuAiApi, MetadataMode.EMBED, + ZhiPuAiEmbeddingOptions.builder().build(), this.retryTemplate); + this.imageModel = new ZhiPuAiImageModel(this.zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), + this.retryTemplate); } @Test @@ -119,24 +104,24 @@ public void zhiPuAiChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, new ZhiPuAiApi.Usage(10, 10, 10)); - when(zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiChatNonTransientError() { - when(zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -147,24 +132,24 @@ public void zhiPuAiChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, null); - when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiChatStreamNonTransientError() { - when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } @Test @@ -173,24 +158,25 @@ public void zhiPuAiEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new ZhiPuAiApi.Usage(10, 10, 10)); - when(zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + when(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(0); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiEmbeddingNonTransientError() { - when(zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + when(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -199,25 +185,44 @@ public void zhiPuAiImageTransientError() { var expectedResponse = new ZhiPuAiImageResponse(678l, List.of(new Data("url678"))); - when(zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + when(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiImageNonTransientError() { - when(zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + when(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) .thenThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, - () -> imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java index 008ffecdb5d..26d1ec5ad9d 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.chat; import java.util.List; @@ -30,7 +31,7 @@ public ActorsFilms() { } public String getActor() { - return actor; + return this.actor; } public void setActor(String actor) { @@ -38,7 +39,7 @@ public void setActor(String actor) { } public List getMovies() { - return movies; + return this.movies; } public void setMovies(List movies) { @@ -47,7 +48,7 @@ public void setMovies(List movies) { @Override public String toString() { - return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}'; } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java index 45a38e29e59..5c0b3737491 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.chat; +import java.io.IOException; +import java.net.URL; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -47,16 +59,6 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; - -import java.io.IOException; -import java.net.URL; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -67,14 +69,14 @@ @EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") class ZhiPuAiChatModelIT { + private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatModelIT.class); + @Autowired protected ChatModel chatModel; @Autowired protected StreamingChatModel streamingChatModel; - private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatModelIT.class); - @Value("classpath:/prompts/system-message.st") private Resource systemResource; @@ -82,10 +84,10 @@ class ZhiPuAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); @@ -95,10 +97,10 @@ void roleTest() { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = streamingChatModel.stream(prompt); + Flux flux = this.streamingChatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -146,7 +148,7 @@ void mapOutputConverter() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -165,14 +167,11 @@ void beanOutputConverter() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -185,7 +184,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -207,7 +206,7 @@ void beanStreamOutputConverterRecords() { Prompt prompt = new Prompt(promptTemplate.createMessage()); String generationTextFromStream = Objects - .requireNonNull(streamingChatModel.stream(prompt).collectList().block()) + .requireNonNull(this.streamingChatModel.stream(prompt).collectList().block()) .stream() .map(ChatResponse::getResults) .flatMap(List::stream) @@ -238,7 +237,7 @@ void functionCallTest() { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -264,7 +263,7 @@ void streamFunctionCallTest() { .build())) .build(); - Flux response = streamingChatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); String content = Objects.requireNonNull(response.collectList().block()) .stream() @@ -289,7 +288,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -305,7 +304,7 @@ void multiModalityImageUrl(String modelName) throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -320,7 +319,7 @@ void streamingMultiModalityImageUrl() throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = streamingChatModel.stream(new Prompt(List.of(userMessage), + Flux response = this.streamingChatModel.stream(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.ChatModel.GLM_4V.getValue()).build())); String content = Objects.requireNonNull(response.collectList().block()) @@ -335,4 +334,8 @@ void streamingMultiModalityImageUrl() throws IOException { assertThat(content).containsAnyOf("bowl", "basket"); } -} \ No newline at end of file + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java index 162a56b4fb7..69fad096353 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class ZhiPuAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -77,7 +79,7 @@ void observationForChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -98,7 +100,7 @@ void observationForStreamingChatOperation() { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -119,7 +121,7 @@ void observationForStreamingChatOperation() { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java index 1371ecde5e4..447546d60db 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.embedding; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingModel; import org.springframework.ai.zhipuai.ZhiPuAiTestConfiguration; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,23 +41,23 @@ class EmbeddingIT { @Test void defaultEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World", "HI")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -65,7 +67,7 @@ void batchEmbedding() { assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java index 04f70b01087..9ad910595d7 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,8 +39,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -64,13 +66,13 @@ void observationForEmbeddingOperation() { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java index 618bdf48794..474bb499c57 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.image; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptionsBuilder; @@ -45,7 +47,7 @@ void imageAsUrlTest() { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/mvnw b/mvnw index a16b5431b4c..657b412d449 100755 --- a/mvnw +++ b/mvnw @@ -1,22 +1,19 @@ #!/bin/sh -# ---------------------------------------------------------------------------- -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- # Maven Start Up Batch script diff --git a/pom.xml b/pom.xml index 46366508288..872ad4ed7fc 100644 --- a/pom.xml +++ b/pom.xml @@ -1,3 +1,19 @@ + + 4.0.0 @@ -225,6 +241,11 @@ 3.3.0 0.0.43 + 3.5.0 + true + true + 9.3 + true @@ -243,6 +264,50 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + ${maven-checkstyle-plugin.version} + + + com.puppycrawl.tools + checkstyle + ${puppycrawl-tools-checkstyle.version} + + + io.spring.javaformat + spring-javaformat-checkstyle + 0.0.43 + + + + + checkstyle-validation + validate + true + + ${disable.checks} + src/checkstyle/checkstyle.xml + src/checkstyle/checkstyle-header.txt + true + + checkstyle.build.directory=${project.build.directory} + checkstyle.suppressions.file=${project.basedir}/src/checkstyle/checkstyle-suppressions.xml + checkstyle.additional.suppressions.file=${project.basedir}/src/checkstyle/checkstyle-suppressions.xml + + true + ${maven-checkstyle-plugin.failsOnError} + + + ${maven-checkstyle-plugin.failOnViolation} + + + + check + + + + org.apache.maven.plugins maven-site-plugin diff --git a/settings.xml b/settings.xml index 890e9307091..e86c337870f 100644 --- a/settings.xml +++ b/settings.xml @@ -1,3 +1,19 @@ + + + + 4.0.0 diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml index 8a25edad098..8322b5f665c 100644 --- a/spring-ai-core/pom.xml +++ b/spring-ai-core/pom.xml @@ -1,6 +1,23 @@ - + + + 4.0.0 org.springframework.ai @@ -21,6 +38,7 @@ 4.13.1 + false @@ -122,7 +140,7 @@ test - + diff --git a/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java index a3f54a67cd5..8e48f220abd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import java.io.IOException; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java index 39cd3cd8f4d..286059bd3af 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; +import java.lang.reflect.Executable; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.aot.hint.TypeReference; import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider; import org.springframework.core.type.filter.AnnotationTypeFilter; import org.springframework.core.type.filter.TypeFilter; -import java.lang.reflect.Executable; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - /** * Utility methods for creating native runtime hints. See other modules for their * respective native runtime hints. @@ -89,8 +91,9 @@ public static Set findClassesInPackage(String packageName, TypeFi .stream()// .map(bd -> TypeReference.of(Objects.requireNonNull(bd.getBeanClassName())))// .peek(tr -> { - if (log.isDebugEnabled()) + if (log.isDebugEnabled()) { log.debug("registering [" + tr.getName() + ']'); + } }) .collect(Collectors.toUnmodifiableSet()); } @@ -154,4 +157,4 @@ private static Set> discoverJacksonAnnotatedTypesFromRootType(Class return jsonTypes; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java index fb676484d94..c7fd1dcb8a0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import org.springframework.aot.hint.RuntimeHints; @@ -26,4 +27,4 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { hints.resources().registerResource(new ClassPathResource("/com/knuddels/jtokkit/cl100k_base.tiktoken")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java index 393687becde..2ee283a69c4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; -import org.springframework.ai.chat.messages.Message; +import java.lang.reflect.Method; +import java.util.Set; + import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -33,9 +37,6 @@ import org.springframework.lang.Nullable; import org.springframework.util.ReflectionUtils; -import java.lang.reflect.Method; -import java.util.Set; - public class SpringAiCoreRuntimeHints implements RuntimeHintsRegistrar { @Override @@ -56,9 +57,10 @@ public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader cla hints.reflection().registerMethod(getName, ExecutableMode.INVOKE); for (var r : Set.of("antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4", - "embedding/embedding-model-dimensions.properties")) + "embedding/embedding-model-dimensions.properties")) { hints.resources().registerResource(new ClassPathResource(r)); + } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java index c6de0ed68ee..ae89587fbc3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; +import java.util.Objects; + import org.springframework.ai.model.ModelResult; import org.springframework.lang.Nullable; -import java.util.Objects; - /** * Represents a response returned by the AI. * @@ -44,7 +45,7 @@ public String getOutput() { @Override public AudioTranscriptionMetadata getMetadata() { - return transcriptionMetadata != null ? transcriptionMetadata : AudioTranscriptionMetadata.NULL; + return this.transcriptionMetadata != null ? this.transcriptionMetadata : AudioTranscriptionMetadata.NULL; } public AudioTranscription withTranscriptionMetadata(@Nullable AudioTranscriptionMetadata transcriptionMetadata) { @@ -54,21 +55,24 @@ public AudioTranscription withTranscriptionMetadata(@Nullable AudioTranscription @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AudioTranscription that)) + } + if (!(o instanceof AudioTranscription that)) { return false; - return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata); + } + return Objects.equals(this.text, that.text) + && Objects.equals(this.transcriptionMetadata, that.transcriptionMetadata); } @Override public int hashCode() { - return Objects.hash(text, transcriptionMetadata); + return Objects.hash(this.text, this.transcriptionMetadata); } @Override public String toString() { - return "Transcript{" + "text=" + text + ", transcriptionMetadata=" + transcriptionMetadata + '}'; + return "Transcript{" + "text=" + this.text + ", transcriptionMetadata=" + this.transcriptionMetadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java index bd064a6596c..5fc1ea10690 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.ResultMetadata; @@ -32,6 +33,7 @@ public interface AudioTranscriptionMetadata extends ResultMetadata { */ static AudioTranscriptionMetadata create() { return new AudioTranscriptionMetadata() { + }; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java index 95bd877e717..7fec8fa97d2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java index 6f5208240e3..07ca8f644eb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.ModelRequest; @@ -57,12 +58,12 @@ public AudioTranscriptionPrompt(Resource audioResource, AudioTranscriptionOption @Override public Resource getInstructions() { - return audioResource; + return this.audioResource; } @Override public AudioTranscriptionOptions getOptions() { - return modelOptions; + return this.modelOptions; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java index e1a652355d0..6bbe17c51f9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.audio.transcription; -import org.springframework.ai.model.ModelResponse; +package org.springframework.ai.audio.transcription; import java.util.List; +import org.springframework.ai.model.ModelResponse; + /** * @author Michael Lavelle * @author Piotr Olaszewski @@ -42,17 +43,17 @@ public AudioTranscriptionResponse(AudioTranscription transcript, @Override public AudioTranscription getResult() { - return transcript; + return this.transcript; } @Override public List getResults() { - return List.of(transcript); + return List.of(this.transcript); } @Override public AudioTranscriptionResponseMetadata getMetadata() { - return transcriptionResponseMetadata; + return this.transcriptionResponseMetadata; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java index 66c3fdf8972..7c4d4941189 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.MutableResponseMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 6f8d2dad691..fd282fd319d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client; import java.net.URL; @@ -21,6 +22,9 @@ import java.util.Map; import java.util.function.Consumer; +import io.micrometer.observation.ObservationRegistry; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.messages.Message; @@ -36,9 +40,6 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeType; -import io.micrometer.observation.ObservationRegistry; -import reactor.core.publisher.Flux; - /** * Client to perform stateless requests to an AI Model, using a fluent API. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java index b1cdf3cc2f2..bbad16c2835 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index af8fc31c168..c0b2d5a8d22 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -28,12 +28,18 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; @@ -62,12 +68,6 @@ import org.springframework.util.MimeType; import org.springframework.util.StringUtils; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; - /** * The default implementation of {@link ChatClient} as created by the * {@link Builder#build()} } method. @@ -91,6 +91,30 @@ public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) this.defaultChatClientRequest = defaultChatClientRequest; } + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); + if (StringUtils.hasText(formatParam)) { + advisorContext.put("formatParam", formatParam); + } + + return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, + inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, + inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, + inputRequest.toolContext); + } + + public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, + ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { + + return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), + advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), + advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), + advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), + advisedRequest.advisorParams(), observationRegistry, customObservationConvention, + advisedRequest.toolContext()); + } + @Override public ChatClientRequestSpec prompt() { return new DefaultChatClientRequestSpec(this.defaultChatClientRequest); @@ -145,12 +169,12 @@ public Builder mutate() { public static class DefaultPromptUserSpec implements PromptUserSpec { - private String text = ""; - private final Map params = new HashMap<>(); private final List media = new ArrayList<>(); + private String text = ""; + @Override public PromptUserSpec media(Media... media) { this.media.addAll(Arrays.asList(media)); @@ -220,10 +244,10 @@ protected List media() { public static class DefaultPromptSystemSpec implements PromptSystemSpec { - private String text = ""; - private final Map params = new HashMap<>(); + private String text = ""; + @Override public PromptSystemSpec text(String text) { this.text = text; @@ -296,11 +320,11 @@ public AdvisorSpec advisors(List advisors) { } public List getAdvisors() { - return advisors; + return this.advisors; } public Map getParams() { - return params; + return this.params; } } @@ -426,12 +450,12 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ var initialAdvisedRequest = toAdvisedRequest(inputRequest, ""); - // @formatter:off + // @formatter:off // Apply the around advisor chain that terminates with the, last, // model call advisor. - Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); + Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); - return stream + return stream .map(AdvisedResponse::response) .doOnError(observation::error) .doFinally(s -> observation.stop()) @@ -464,12 +488,6 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final ChatModel chatModel; - private String userText = ""; - - private String systemText = ""; - - private ChatOptions chatOptions; - private final List media = new ArrayList<>(); private final List functionNames = new ArrayList<>(); @@ -490,61 +508,11 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map toolContext = new HashMap<>(); - private ObservationRegistry getObservationRegistry() { - return this.observationRegistry; - } - - private ChatClientObservationConvention getCustomObservationConvention() { - return this.customObservationConvention; - } - - public String getUserText() { - return this.userText; - } - - public Map getUserParams() { - return this.userParams; - } - - public String getSystemText() { - return this.systemText; - } - - public Map getSystemParams() { - return this.systemParams; - } - - public ChatOptions getChatOptions() { - return this.chatOptions; - } - - public List getAdvisors() { - return this.advisors; - } - - public Map getAdvisorParams() { - return this.advisorParams; - } - - public List getMessages() { - return this.messages; - } - - public List getMedia() { - return this.media; - } - - public List getFunctionNames() { - return this.functionNames; - } + private String userText = ""; - public List getFunctionCallbacks() { - return this.functionCallbacks; - } + private String systemText = ""; - public Map getToolContext() { - return this.toolContext; - } + private ChatOptions chatOptions; /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { @@ -578,13 +546,13 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { return chatModel.stream(advisedRequest.toPrompt()) - .map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))) - .publishOn(Schedulers.boundedElastic());// TODO add option to disable. + .map(chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))) + .publishOn(Schedulers.boundedElastic()); // TODO add option to disable. } }); // @formatter:on @@ -624,13 +592,69 @@ public Flux aroundStream(AdvisedRequest advisedRequest, StreamA .pushAll(this.advisors); } + private ObservationRegistry getObservationRegistry() { + return this.observationRegistry; + } + + private ChatClientObservationConvention getCustomObservationConvention() { + return this.customObservationConvention; + } + + public String getUserText() { + return this.userText; + } + + public Map getUserParams() { + return this.userParams; + } + + public String getSystemText() { + return this.systemText; + } + + public Map getSystemParams() { + return this.systemParams; + } + + public ChatOptions getChatOptions() { + return this.chatOptions; + } + + public List getAdvisors() { + return this.advisors; + } + + public Map getAdvisorParams() { + return this.advisorParams; + } + + public List getMessages() { + return this.messages; + } + + public List getMedia() { + return this.media; + } + + public List getFunctionNames() { + return this.functionNames; + } + + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + public Map getToolContext() { + return this.toolContext; + } + /** * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose * settings are replicated from this {@code ChatClientRequest}. */ public Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient - .builder(chatModel, this.observationRegistry, this.customObservationConvention) + .builder(this.chatModel, this.observationRegistry, this.customObservationConvention) .defaultSystem(s -> s.text(this.systemText).params(this.systemParams)) .defaultUser(u -> u.text(this.userText) .params(this.userParams) @@ -827,30 +851,6 @@ public StreamResponseSpec stream() { } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { - Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); - if (StringUtils.hasText(formatParam)) { - advisorContext.put("formatParam", formatParam); - } - - return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, - inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, - inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, - inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, - inputRequest.toolContext); - } - - public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, - ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { - - return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), - advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), - advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), - advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), - advisedRequest.advisorParams(), observationRegistry, customObservationConvention, - advisedRequest.toolContext()); - } - // Prompt public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec { @@ -877,7 +877,7 @@ public ChatResponse chatResponse() { } private ChatResponse doGetChatResponse(Prompt prompt) { - return chatModel.call(prompt); + return this.chatModel.call(prompt); } } @@ -913,4 +913,4 @@ public Flux content() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 1053a5d02f1..6b03d8e2a40 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ import java.util.Map; import java.util.function.Consumer; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.client.ChatClient.Builder; import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec; import org.springframework.ai.chat.client.ChatClient.PromptUserSpec; @@ -35,8 +37,6 @@ import org.springframework.core.io.Resource; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; - /** * DefaultChatClientBuilder is a builder class for creating a ChatClient. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index bb6e0ae5491..fa1d1526a86 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,18 +21,18 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import reactor.core.publisher.Flux; - /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and * {@link ChatModel#stream(Prompt)} methods calls. The {@link ChatClient} maintains a @@ -90,4 +90,4 @@ default Flux aroundStream(AdvisedRequest advisedRequest, Stream .map(chatResponse -> new AdvisedResponse(chatResponse, context)); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java index 069f46aa675..b6ab8fedda7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,4 +36,5 @@ public R getResponse() { public E getEntity() { return this.entity; } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 45f4bf8a705..304e538656f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,10 @@ import java.util.Map; import java.util.function.Function; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -26,11 +30,6 @@ import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.util.Assert; -import org.stringtemplate.v4.compiler.CodeGenerator.includeExpr_return; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; /** * Abstract class that serves as a base for chat memory advisors. @@ -129,7 +128,7 @@ protected Flux doNextWithProtectFromBlockingBefore(AdvisedReque } public static abstract class AbstractBuilder { - + protected String conversationId = DEFAULT_CHAT_MEMORY_CONVERSATION_ID; protected int chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index fe5e3b4eaa5..ee441d1d686 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor; import java.util.ArrayList; @@ -20,6 +21,10 @@ import java.util.List; import java.util.concurrent.ConcurrentLinkedDeque; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -35,10 +40,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - /** * Implementation of the {@link CallAroundAdvisorChain} and * {@link StreamAroundAdvisorChain}. Used by the @@ -71,6 +72,10 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream this.streamAroundAdvisors = streamAroundAdvisors; } + public static Builder builder(ObservationRegistry observationRegistry) { + return new Builder(observationRegistry); + } + @Override public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { @@ -117,17 +122,13 @@ public Flux nextAroundStream(AdvisedRequest advisedRequest) { // @formatter:off return Flux.defer(() -> advisor.aroundStream(advisedRequest, this)) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on }); } - public static Builder builder(ObservationRegistry observationRegistry) { - return new Builder(observationRegistry); - } - public static class Builder { private final ObservationRegistry observationRegistry; @@ -195,4 +196,4 @@ public DefaultAroundAdvisorChain build() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java index 2c20da81ecc..8534a27503d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,13 @@ package org.springframework.ai.chat.client.advisor; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.model.Content; import org.springframework.ai.model.MediaContent; import org.springframework.ai.tokenizer.TokenCountEstimator; -import java.util.ArrayList; -import java.util.List; - /** * Returns a new list of content (e.g list of messages of list of documents) that is a * subset of the input list of contents and complies with the max token size constraint. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 0677e28ad19..6aa2ca1bcd6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ import java.util.ArrayList; import java.util.List; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -29,8 +31,6 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.MessageAggregator; -import reactor.core.publisher.Flux; - /** * Memory is retrieved added as a collection of messages to the prompt * @@ -52,6 +52,10 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversatio super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order); } + public static Builder builder(ChatMemory chatMemory) { + return new Builder(chatMemory); + } + @Override public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { @@ -107,10 +111,6 @@ private void observeAfter(AdvisedResponse advisedResponse) { this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } - public static Builder builder(ChatMemory chatMemory) { - return new Builder(chatMemory); - } - public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { protected Builder(ChatMemory chatMemory) { @@ -124,4 +124,4 @@ public MessageChatMemoryAdvisor build() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index d183ab31697..2989fa64ce1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ import java.util.Map; import java.util.stream.Collectors; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -33,8 +35,6 @@ import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.Content; -import reactor.core.publisher.Flux; - /** * Memory is retrieved added into the prompt's system text. * @@ -77,6 +77,10 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversation this.systemTextAdvise = systemTextAdvise; } + public static Builder builder(ChatMemory chatMemory) { + return new Builder(chatMemory); + } + @Override public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { @@ -140,10 +144,6 @@ private void observeAfter(AdvisedResponse advisedResponse) { this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } - public static Builder builder(ChatMemory chatMemory) { - return new Builder(chatMemory); - } - public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE; @@ -164,4 +164,4 @@ public PromptChatMemoryAdvisor build() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 33fb8b0f428..b89238c0ff3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,10 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -38,10 +42,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - /** * Context for the question is retrieved from a Vector Store and added to the prompt's * user text. @@ -51,6 +51,10 @@ */ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents"; + + public static final String FILTER_EXPRESSION = "qa_filter_expression"; + private static final String DEFAULT_USER_TEXT_ADVISE = """ Context information is below, surrounded by --------------------- @@ -72,10 +76,6 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv private final SearchRequest searchRequest; - public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents"; - - public static final String FILTER_EXPRESSION = "qa_filter_expression"; - private final boolean protectFromBlocking; private final int order; @@ -159,6 +159,10 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques this.order = order; } + public static Builder builder(VectorStore vectorStore) { + return new Builder(vectorStore); + } + @Override public String getName() { return this.getClass().getSimpleName(); @@ -262,11 +266,7 @@ private Predicate onFinishReason() { .isPresent(); } - public static Builder builder(VectorStore vectorStore) { - return new Builder(vectorStore); - } - - public static class Builder { + public static final class Builder { private final VectorStore vectorStore; @@ -312,4 +312,4 @@ public QuestionAnswerAdvisor build() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java index 0706320b8de..054e3fbf0c2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor; import java.util.List; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -29,8 +32,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; - /** * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the * response if the user input contains any of the sensitive words. @@ -62,6 +63,10 @@ public SafeGuardAdvisor(List sensitiveWords, String failureResponse, int this.order = order; } + public static Builder builder() { + return new Builder(); + } + public String getName() { return this.getClass().getSimpleName(); } @@ -82,7 +87,7 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { if (!CollectionUtils.isEmpty(this.sensitiveWords) - && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + && this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { return Flux.just(createFailureResponse(advisedRequest)); } @@ -100,11 +105,7 @@ public int getOrder() { return this.order; } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { + public static final class Builder { private List sensitiveWords; @@ -136,4 +137,4 @@ public SafeGuardAdvisor build() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index d23388682e6..1eb8a2e3aee 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,24 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.ModelOptionsUtils; -import reactor.core.publisher.Flux; - /** * A simple logger advisor that logs the request and response messages. * @@ -38,10 +39,6 @@ */ public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { - private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); - - private int order; - public static final Function DEFAULT_REQUEST_TO_STRING = (request) -> { return request.toString(); }; @@ -50,10 +47,14 @@ public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvis return ModelOptionsUtils.toJsonString(response); }; + private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); + private final Function requestToString; private final Function responseToString; + private int order; + public SimpleLoggerAdvisor() { this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, 0); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index c844f5dfe87..02acf521223 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ import java.util.Map; import java.util.stream.Collectors; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -36,8 +38,6 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; -import reactor.core.publisher.Flux; - /** * Memory is retrieved from a VectorStore added into the prompt's system text. * @@ -99,6 +99,10 @@ public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConve this.systemTextAdvise = systemTextAdvise; } + public static Builder builder(VectorStore chatMemory) { + return new Builder(chatMemory); + } + @Override public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { @@ -185,10 +189,6 @@ else if (message instanceof AssistantMessage assistantMessage) { return docs; } - public static Builder builder(VectorStore chatMemory) { - return new Builder(chatMemory); - } - public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 032de8b63d0..afca774760f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ import java.util.Map; import java.util.function.Function; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -32,6 +31,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.util.CollectionUtils; @@ -63,12 +63,6 @@ public record AdvisedRequest(ChatModel chatModel, String userText, String system Map userParams, Map systemParams, List advisors, Map advisorParams, Map adviseContext, Map toolContext) { - public AdvisedRequest updateContext(Function, Map> contextTransform) { - return from(this) - .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) - .build(); - } - public static Builder from(AdvisedRequest from) { Builder builder = new Builder(); builder.chatModel = from.chatModel; @@ -93,8 +87,60 @@ public static Builder builder() { return new Builder(); } + public AdvisedRequest updateContext(Function, Map> contextTransform) { + return from(this) + .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) + .build(); + } + + public Prompt toPrompt() { + + var messages = new ArrayList(this.messages()); + + String processedSystemText = this.systemText(); + if (StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(this.systemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + String formatParam = (String) this.adviseContext().get("formatParam"); + + var processedUserText = StringUtils.hasText(formatParam) + ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); + + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(this.userParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, this.media())); + } + + if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!this.functionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); + } + if (!this.functionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); + } + if (!CollectionUtils.isEmpty(this.toolContext())) { + functionCallingOptions.setToolContext(this.toolContext()); + } + } + + return new Prompt(messages, this.chatOptions()); + } + public static class Builder { + public Map toolContext = Map.of(); + private ChatModel chatModel; private String userText = ""; @@ -121,8 +167,6 @@ public static class Builder { private Map adviseContext = Map.of(); - public Map toolContext = Map.of(); - public Builder withChatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; @@ -194,55 +238,11 @@ public Builder withAdviseContext(Map adviseContext) { } public AdvisedRequest build() { - return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media, + return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, this.advisors, this.advisorParams, this.adviseContext, this.toolContext); } } - public Prompt toPrompt() { - - var messages = new ArrayList(this.messages()); - - String processedSystemText = this.systemText(); - if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(this.systemParams())) { - processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render(); - } - messages.add(new SystemMessage(processedSystemText)); - } - - String formatParam = (String) this.adviseContext().get("formatParam"); - - var processedUserText = StringUtils.hasText(formatParam) - ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); - - if (StringUtils.hasText(processedUserText)) { - - Map userParams = new HashMap<>(this.userParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = new PromptTemplate(processedUserText, userParams).render(); - } - messages.add(new UserMessage(processedUserText, this.media())); - } - - if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!this.functionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); - } - if (!this.functionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); - } - if (!CollectionUtils.isEmpty(this.toolContext())) { - functionCallingOptions.setToolContext(this.toolContext()); - } - } - - return new Prompt(messages, this.chatOptions()); - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java index 8c81740cfd3..a03247fd6dd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import java.util.Collections; @@ -29,15 +30,15 @@ */ public record AdvisedResponse(ChatResponse response, Map adviseContext) { - public AdvisedResponse updateContext(Function, Map> contextTransform) { - return new AdvisedResponse(response, - Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(adviseContext)))); - } - public static Builder builder() { return new Builder(); } + public AdvisedResponse updateContext(Function, Map> contextTransform) { + return new AdvisedResponse(this.response, + Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))); + } + public static class Builder { private ChatResponse response; @@ -66,8 +67,9 @@ public Builder withAdviseContext(Map adviseContext) { public AdvisedResponse build() { Assert.notNull(this.adviseContext, "the adviseContext must be non-null"); - return new AdvisedResponse(response, adviseContext); + return new AdvisedResponse(this.response, this.adviseContext); } } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java index c7a931b8504..c03eb507e99 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import org.springframework.core.Ordered; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java index 57d19df600e..05369aace23 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; /** @@ -31,4 +32,4 @@ public interface CallAroundAdvisor extends Advisor { */ AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java index 9a51a01faca..9158a721265 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; /** @@ -34,4 +35,4 @@ public interface CallAroundAdvisorChain { */ AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java index eeb65aa666e..56ff624f0f1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import reactor.core.publisher.Flux; @@ -32,4 +33,4 @@ public interface StreamAroundAdvisor extends Advisor { */ Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java index 43837fd43f8..175ae9e71fa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import reactor.core.publisher.Flux; @@ -36,4 +37,4 @@ public interface StreamAroundAdvisorChain { */ Flux nextAroundStream(AdvisedRequest advisedRequest); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java index 9effbf29bd1..394e89efd2b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java @@ -1,30 +1,31 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.observation; import java.util.Map; +import io.micrometer.observation.Observation; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import io.micrometer.observation.Observation; - /** * Context used to store metadata for chat client advisors. * @@ -34,16 +35,15 @@ */ public class AdvisorObservationContext extends Observation.Context { - public enum Type { - - BEFORE, AFTER, AROUND - - } - private final String advisorName; private final Type advisorType; + /** + * The order of the advisor in the advisor chain. + */ + private final int order; + /** * The {@link AdvisedRequest} data to be advised. Represents the row * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. @@ -65,11 +65,6 @@ public enum Type { @Nullable private Map advisorResponseContext; - /** - * The order of the advisor in the advisor chain. - */ - private final int order; - public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest, @Nullable Map advisorRequestContext, @Nullable Map advisorResponseContext, int order) { @@ -84,6 +79,10 @@ public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable this.order = order; } + public static Builder builder() { + return new Builder(); + } + public String getAdvisorName() { return this.advisorName; } @@ -123,8 +122,10 @@ public int getOrder() { return this.order; } - public static Builder builder() { - return new Builder(); + public enum Type { + + BEFORE, AFTER, AROUND + } public static class Builder { @@ -172,10 +173,10 @@ public Builder withOrder(int order) { } public AdvisorObservationContext build() { - return new AdvisorObservationContext(advisorName, advisorType, advisorRequest, advisorRequestContext, - advisorResponseContext, order); + return new AdvisorObservationContext(this.advisorName, this.advisorType, this.advisorRequest, + this.advisorRequestContext, this.advisorResponseContext, this.order); } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java index 7726c65e99c..10e5212c1ef 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.observation; import io.micrometer.observation.Observation; @@ -32,4 +33,4 @@ default boolean supportsContext(Observation.Context context) { return context instanceof AdvisorObservationContext; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java index 9d1d4fb17f6..835dfe614fc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.observation; import io.micrometer.common.docs.KeyName; @@ -94,4 +95,4 @@ public String asString() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java index 9e7d5721b7d..2b4ce63cd4e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java @@ -1,29 +1,30 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.observation; +import io.micrometer.common.KeyValue; +import io.micrometer.common.KeyValues; + import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.util.ParsingUtils; import org.springframework.lang.Nullable; -import io.micrometer.common.KeyValue; -import io.micrometer.common.KeyValues; - /** * @author Christian Tzolov * @since 1.0.0 diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/package-info.java index c84a0a9209d..d901e06c610 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java index 78f16c9b255..bd9918d5631 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import org.springframework.ai.observation.tracing.TracingHelper; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; +package org.springframework.ai.chat.client.observation; import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationFilter; +import org.springframework.ai.observation.tracing.TracingHelper; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + /** * An {@link ObservationFilter} to include the chat prompt content in the observation. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java index fbf9557f24d..6ad0d244fac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java @@ -1,26 +1,27 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.observation; +import io.micrometer.observation.Observation; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; - -import io.micrometer.observation.Observation; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -50,6 +51,10 @@ public class ChatClientObservationContext extends Observation.Context { this.stream = isStream; } + public static Builder builder() { + return new Builder(); + } + public DefaultChatClientRequestSpec getRequest() { return this.request; } @@ -71,11 +76,7 @@ public void setFormat(@Nullable String format) { this.format = format; } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { + public static final class Builder { private DefaultChatClientRequestSpec request; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationConvention.java index 047608d8ef9..8431f3088e6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationConvention.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.observation; import io.micrometer.observation.Observation; @@ -31,4 +32,4 @@ default boolean supportsContext(Observation.Context context) { return context instanceof ChatClientObservationContext; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java index 09c34bd45d4..c39f3ffd40e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.observation; import io.micrometer.common.docs.KeyName; @@ -147,7 +148,7 @@ public String asString() { public String asString() { return "spring.ai.chat.client.system.params"; } - }; + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java index e2ab324778d..b51ade63ca2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java @@ -1,20 +1,24 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.observation; +import io.micrometer.common.KeyValue; +import io.micrometer.common.KeyValues; + import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -24,9 +28,6 @@ import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; -import io.micrometer.common.KeyValue; -import io.micrometer.common.KeyValues; - /** * Default conventions to populate observations for chat client workflows. * @@ -136,4 +137,4 @@ protected KeyValues toolFunctionCallbacks(KeyValues keyValues, ChatClientObserva .asString(), TracingHelper.concatenateStrings(functionCallbacks)); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/package-info.java index 4a91aabb600..b624f47a7cd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java index c4e9b33f0db..7003457df16 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,4 +42,4 @@ default void add(String conversationId, Message message) { void clear(String conversationId); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java index 4e8578e2f08..73318fc8c4b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,4 +56,4 @@ public void clear(String conversationId) { this.conversationHistory.remove(conversationId); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 05a89117c6b..8fa9514a305 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.messages; -import org.springframework.core.io.Resource; -import org.springframework.util.Assert; -import org.springframework.util.StreamUtils; +package org.springframework.ai.chat.messages; import java.io.IOException; import java.io.InputStream; @@ -26,6 +23,10 @@ import java.util.Map; import java.util.Objects; +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; + /** * The AbstractMessage class is an abstract implementation of the Message interface. It * provides a base implementation for message content, media attachments, metadata, and @@ -87,17 +88,19 @@ public MessageType getMessageType() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AbstractMessage that)) + } + if (!(o instanceof AbstractMessage that)) { return false; - return messageType == that.messageType && Objects.equals(textContent, that.textContent) - && Objects.equals(metadata, that.metadata); + } + return this.messageType == that.messageType && Objects.equals(this.textContent, that.textContent) + && Objects.equals(this.metadata, that.metadata); } @Override public int hashCode() { - return Objects.hash(messageType, textContent, metadata); + return Objects.hash(this.messageType, this.textContent, this.metadata); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index fef22f67727..9a974ec2da1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.List; @@ -34,9 +35,6 @@ */ public class AssistantMessage extends AbstractMessage { - public record ToolCall(String id, String type, String name, String arguments) { - } - private final List toolCalls; public AssistantMessage(String content) { @@ -63,24 +61,31 @@ public boolean hasToolCalls() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AssistantMessage that)) + } + if (!(o instanceof AssistantMessage that)) { return false; - if (!super.equals(o)) + } + if (!super.equals(o)) { return false; - return Objects.equals(toolCalls, that.toolCalls); + } + return Objects.equals(this.toolCalls, that.toolCalls); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), toolCalls); + return Objects.hash(super.hashCode(), this.toolCalls); } @Override public String toString() { - return "AssistantMessage [messageType=" + messageType + ", toolCalls=" + toolCalls + ", textContent=" - + textContent + ", metadata=" + metadata + "]"; + return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + this.toolCalls + ", textContent=" + + this.textContent + ", metadata=" + this.metadata + "]"; + } + + public record ToolCall(String id, String type, String name, String arguments) { + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java index bdd11e9a64a..089b88b8a8b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import org.springframework.ai.model.Content; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java index 7603ab39f41..876b004eebe 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; /** @@ -56,10 +57,6 @@ public enum MessageType { this.value = value; } - public String getValue() { - return value; - } - public static MessageType fromValue(String value) { for (MessageType messageType : MessageType.values()) { if (messageType.getValue().equals(value)) { @@ -69,4 +66,8 @@ public static MessageType fromValue(String value) { throw new IllegalArgumentException("Invalid MessageType value: " + value); } + public String getValue() { + return this.value; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java index ddcff796678..e673de98a69 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.Map; @@ -44,24 +45,27 @@ public String getContent() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SystemMessage that)) + } + if (!(o instanceof SystemMessage that)) { return false; - if (!super.equals(o)) + } + if (!super.equals(o)) { return false; - return Objects.equals(textContent, that.textContent); + } + return Objects.equals(this.textContent, that.textContent); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), textContent); + return Objects.hash(super.hashCode(), this.textContent); } @Override public String toString() { - return "SystemMessage{" + "textContent='" + textContent + '\'' + ", messageType=" + messageType + ", metadata=" - + metadata + '}'; + return "SystemMessage{" + "textContent='" + this.textContent + '\'' + ", messageType=" + this.messageType + + ", metadata=" + this.metadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 42f91f9df54..47da252180f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.List; @@ -28,9 +29,6 @@ */ public class ToolResponseMessage extends AbstractMessage { - public record ToolResponse(String id, String name, String responseData) { - }; - protected final List responses; public ToolResponseMessage(List responses) { @@ -48,24 +46,31 @@ public List getResponses() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ToolResponseMessage that)) + } + if (!(o instanceof ToolResponseMessage that)) { return false; - if (!super.equals(o)) + } + if (!super.equals(o)) { return false; - return Objects.equals(responses, that.responses); + } + return Objects.equals(this.responses, that.responses); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), responses); + return Objects.hash(super.hashCode(), this.responses); } @Override public String toString() { - return "ToolResponseMessage{" + "responses=" + responses + ", messageType=" + messageType + ", metadata=" - + metadata + '}'; + return "ToolResponseMessage{" + "responses=" + this.responses + ", messageType=" + this.messageType + + ", metadata=" + this.metadata + '}'; + } + + public record ToolResponse(String id, String name, String responseData) { + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 53c32425722..5a7e7db5794 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.ArrayList; @@ -69,8 +70,8 @@ public List getMedia(String... dummy) { @Override public String toString() { - return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + metadata + ", messageType=" - + messageType + '}'; + return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + this.metadata + ", messageType=" + + this.messageType + '}'; } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java index 744c3fdab3e..77728e276c8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import org.springframework.ai.model.ResultMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java index f58bc5d240b..20126472b92 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.util.Map; @@ -20,6 +21,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.AbstractResponseMetadata; import org.springframework.ai.model.ResponseMetadata; @@ -36,7 +38,8 @@ public class ChatResponseMetadata extends AbstractResponseMetadata implements Re private final static Logger logger = LoggerFactory.getLogger(ChatResponseMetadata.class); private String id = ""; // Set to blank to preserve backward compat with previous - // interface default methods + + // interface default methods private String model = ""; @@ -46,6 +49,10 @@ public class ChatResponseMetadata extends AbstractResponseMetadata implements Re private PromptMetadata promptMetadata = PromptMetadata.empty(); + public static Builder builder() { + return new Builder(); + } + /** * A unique identifier for the chat completion operation. * @return unique operation identifier. @@ -88,6 +95,29 @@ public PromptMetadata getPromptMetadata() { return this.promptMetadata; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ChatResponseMetadata that)) { + return false; + } + return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) + && Objects.equals(this.rateLimit, that.rateLimit) && Objects.equals(this.usage, that.usage) + && Objects.equals(this.promptMetadata, that.promptMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(this.id, this.model, this.rateLimit, this.usage, this.promptMetadata); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit()); + } + public static class Builder { private final ChatResponseMetadata chatResponseMetadata; @@ -145,29 +175,4 @@ public ChatResponseMetadata build() { } - public static Builder builder() { - return new Builder(); - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof ChatResponseMetadata that)) - return false; - return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) - && Objects.equals(this.rateLimit, that.rateLimit) && Objects.equals(this.usage, that.usage) - && Objects.equals(this.promptMetadata, that.promptMetadata); - } - - @Override - public int hashCode() { - return Objects.hash(this.id, this.model, this.rateLimit, this.usage, this.promptMetadata); - } - - @Override - public String toString() { - return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit()); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java index 5f5ee9c5176..a9fa52a30db 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.Objects; - /** * Default implementation of the {@link Usage} interface. * @@ -66,19 +67,19 @@ public DefaultUsage(@JsonProperty("promptTokens") Long promptTokens, @Override @JsonProperty("promptTokens") public Long getPromptTokens() { - return promptTokens; + return this.promptTokens; } @Override @JsonProperty("generationTokens") public Long getGenerationTokens() { - return generationTokens; + return this.generationTokens; } @Override @JsonProperty("totalTokens") public Long getTotalTokens() { - return totalTokens; + return this.totalTokens; } private Long calculateTotalTokens(Long promptTokens, Long generationTokens) { @@ -87,25 +88,27 @@ private Long calculateTotalTokens(Long promptTokens, Long generationTokens) { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } DefaultUsage that = (DefaultUsage) o; - return Objects.equals(promptTokens, that.promptTokens) - && Objects.equals(generationTokens, that.generationTokens) - && Objects.equals(totalTokens, that.totalTokens); + return Objects.equals(this.promptTokens, that.promptTokens) + && Objects.equals(this.generationTokens, that.generationTokens) + && Objects.equals(this.totalTokens, that.totalTokens); } @Override public int hashCode() { - return Objects.hash(promptTokens, generationTokens, totalTokens); + return Objects.hash(this.promptTokens, this.generationTokens, this.totalTokens); } @Override public String toString() { - return "DefaultUsage{" + "promptTokens=" + promptTokens + ", generationTokens=" + generationTokens - + ", totalTokens=" + totalTokens + '}'; + return "DefaultUsage{" + "promptTokens=" + this.promptTokens + ", generationTokens=" + this.generationTokens + + ", totalTokens=" + this.totalTokens + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java index 80eb2462c8c..0506dbbf268 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java index 48cf590e716..b9cdaf87249 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java index bbc32a7913b..c78e61c27f2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.util.Arrays; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java index 68938021976..9cedeecd735 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java index d2dffc808bb..887bfbaa4c5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index e32b4da6daa..38ba886f69a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.ArrayList; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java index 5f2687bca94..e72a1c23e2c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.model; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; +package org.springframework.ai.chat.model; import java.util.Arrays; @@ -24,6 +22,8 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Model; public interface ChatModel extends Model, StreamingChatModel { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java index 58d75e2182a..657a4ef9f4d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.List; @@ -20,9 +21,9 @@ import java.util.Objects; import java.util.Set; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.model.ModelResponse; import org.springframework.util.CollectionUtils; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; /** * The chat completion (e.g. generation) response returned by an AI provider. @@ -57,6 +58,10 @@ public ChatResponse(List generations, ChatResponseMetadata chatRespo this.generations = List.copyOf(generations); } + public static ChatResponse.Builder builder() { + return new ChatResponse.Builder(); + } + /** * The {@link List} of {@link Generation generated outputs}. *

@@ -91,29 +96,27 @@ public ChatResponseMetadata getMetadata() { @Override public String toString() { - return "ChatResponse [metadata=" + chatResponseMetadata + ", generations=" + generations + "]"; + return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]"; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ChatResponse that)) + } + if (!(o instanceof ChatResponse that)) { return false; - return Objects.equals(chatResponseMetadata, that.chatResponseMetadata) - && Objects.equals(generations, that.generations); + } + return Objects.equals(this.chatResponseMetadata, that.chatResponseMetadata) + && Objects.equals(this.generations, that.generations); } @Override public int hashCode() { - return Objects.hash(chatResponseMetadata, generations); - } - - public static ChatResponse.Builder builder() { - return new ChatResponse.Builder(); + return Objects.hash(this.chatResponseMetadata, this.generations); } - public static class Builder { + public static final class Builder { private List generations; @@ -149,7 +152,7 @@ public Builder withGenerations(List generations) { } public ChatResponse build() { - return new ChatResponse(generations, chatResponseMetadataBuilder.build()); + return new ChatResponse(this.generations, this.chatResponseMetadataBuilder.build()); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java index 9d98d15cd0d..210935eadf5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.Map; @@ -82,10 +83,12 @@ public Generation withGenerationMetadata(@Nullable ChatGenerationMetadata chatGe @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Generation that)) + } + if (!(o instanceof Generation that)) { return false; + } return Objects.equals(this.assistantMessage, that.assistantMessage) && Objects.equals(this.chatGenerationMetadata, that.chatGenerationMetadata); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 1c6bfc70225..fb0a1395a23 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; @@ -34,8 +36,6 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - /** * Helper that for streaming chat responses, aggregate the chat response messages into a * single AssistantMessage. Job is performed in parallel to the chat response processing. @@ -188,6 +188,7 @@ public Long getGenerationTokens() { public Long getTotalTokens() { return totalTokens(); } + } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java index 2eab40e4532..9105add8c36 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.Arrays; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java index 2d49e1ebcba..5ba3a60ebf6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.model; import java.util.Collections; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java index 72cb7d82591..c68227bb50f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationFilter; + import org.springframework.ai.observation.tracing.TracingHelper; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java index 9404c25567b..59612fa263a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.observation.Observation; @@ -21,6 +22,7 @@ import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; + import org.springframework.ai.observation.conventions.AiObservationAttributes; import org.springframework.ai.observation.conventions.AiObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java index 1604e0451ac..9b19d4199dc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationHandler; + import org.springframework.ai.model.observation.ModelUsageMetricsGenerator; /** @@ -38,7 +40,8 @@ public ChatModelMeterObservationHandler(MeterRegistry meterRegistry) { public void onStop(ChatModelObservationContext context) { if (context.getResponse() != null && context.getResponse().getMetadata() != null && context.getResponse().getMetadata().getUsage() != null) { - ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context, meterRegistry); + ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context, + this.meterRegistry); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java index 2e13b571a58..3de4a321532 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import org.springframework.ai.model.Content; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.util.List; - /** * Utilities to process the prompt and completion content in observations for chat models. * @@ -28,6 +29,9 @@ */ public final class ChatModelObservationContentProcessor { + private ChatModelObservationContentProcessor() { + } + public static List prompt(ChatModelObservationContext context) { if (CollectionUtils.isEmpty(context.getRequest().getInstructions())) { return List.of(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java index eb20f161a62..525f5fab355 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import org.springframework.ai.chat.model.ChatResponse; @@ -40,15 +41,15 @@ public class ChatModelObservationContext extends ModelObservationContext stop) { + this.options.setStopSequences(stop); + return this; + } + + public ChatOptionsBuilder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public ChatOptionsBuilder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public ChatOptionsBuilder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public ChatOptions build() { + return this.options; + } private static class DefaultChatOptions implements ChatOptions { @@ -39,7 +93,7 @@ private static class DefaultChatOptions implements ChatOptions { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -48,7 +102,7 @@ public void setModel(String model) { @Override public Double getFrequencyPenalty() { - return frequencyPenalty; + return this.frequencyPenalty; } public void setFrequencyPenalty(Double frequencyPenalty) { @@ -57,7 +111,7 @@ public void setFrequencyPenalty(Double frequencyPenalty) { @Override public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } public void setMaxTokens(Integer maxTokens) { @@ -66,7 +120,7 @@ public void setMaxTokens(Integer maxTokens) { @Override public Double getPresencePenalty() { - return presencePenalty; + return this.presencePenalty; } public void setPresencePenalty(Double presencePenalty) { @@ -75,7 +129,7 @@ public void setPresencePenalty(Double presencePenalty) { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -84,7 +138,7 @@ public void setStopSequences(List stopSequences) { @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -93,7 +147,7 @@ public void setTemperature(Double temperature) { @Override public Integer getTopK() { - return topK; + return this.topK; } public void setTopK(Integer topK) { @@ -102,7 +156,7 @@ public void setTopK(Integer topK) { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -124,57 +178,4 @@ public ChatOptions copy() { } - private final DefaultChatOptions options = new DefaultChatOptions(); - - private ChatOptionsBuilder() { - } - - public static ChatOptionsBuilder builder() { - return new ChatOptionsBuilder(); - } - - public ChatOptionsBuilder withModel(String model) { - options.setModel(model); - return this; - } - - public ChatOptionsBuilder withFrequencyPenalty(Double frequencyPenalty) { - options.setFrequencyPenalty(frequencyPenalty); - return this; - } - - public ChatOptionsBuilder withMaxTokens(Integer maxTokens) { - options.setMaxTokens(maxTokens); - return this; - } - - public ChatOptionsBuilder withPresencePenalty(Double presencePenalty) { - options.setPresencePenalty(presencePenalty); - return this; - } - - public ChatOptionsBuilder withStopSequences(List stop) { - options.setStopSequences(stop); - return this; - } - - public ChatOptionsBuilder withTemperature(Double temperature) { - options.setTemperature(temperature); - return this; - } - - public ChatOptionsBuilder withTopK(Integer topK) { - options.setTopK(topK); - return this; - } - - public ChatOptionsBuilder withTopP(Double topP) { - options.setTopP(topP); - return this; - } - - public ChatOptions build() { - return options; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java index 4db4aee5746..9183161871c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.prompt; -import org.springframework.ai.chat.messages.Message; +package org.springframework.ai.chat.prompt; import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.messages.Message; + /** * A PromptTemplate that lets you specify the role as a string should the current * implementations and their roles not suffice for your needs. @@ -36,7 +37,7 @@ public ChatPromptTemplate(List promptTemplates) { @Override public String render() { StringBuilder sb = new StringBuilder(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { sb.append(promptTemplate.render()); } return sb.toString(); @@ -45,7 +46,7 @@ public String render() { @Override public String render(Map model) { StringBuilder sb = new StringBuilder(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { sb.append(promptTemplate.render(model)); } return sb.toString(); @@ -54,7 +55,7 @@ public String render(Map model) { @Override public List createMessages() { List messages = new ArrayList<>(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { messages.add(promptTemplate.createMessage()); } return messages; @@ -63,7 +64,7 @@ public List createMessages() { @Override public List createMessages(Map model) { List messages = new ArrayList<>(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { messages.add(promptTemplate.createMessage(model)); } return messages; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java index 913c18c85fd..3a8b368893a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; public class FunctionPromptTemplate extends PromptTemplate { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 3d36b1bfbf4..743314f8d4a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.util.ArrayList; @@ -91,10 +92,12 @@ public String toString() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Prompt prompt)) + } + if (!(o instanceof Prompt prompt)) { return false; + } return Objects.equals(this.messages, prompt.messages) && Objects.equals(this.chatOptions, prompt.chatOptions); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java index 162b0d59ffa..852089e2b23 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.io.IOException; @@ -30,22 +31,22 @@ import org.stringtemplate.v4.ST; import org.stringtemplate.v4.compiler.STLexer; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.model.Media; import org.springframework.core.io.Resource; import org.springframework.util.StreamUtils; public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions { - private ST st; - - private Map dynamicModel = new HashMap<>(); - protected String template; protected TemplateFormat templateFormat = TemplateFormat.ST; + private ST st; + + private Map dynamicModel = new HashMap<>(); + public PromptTemplate(Resource resource) { try (InputStream inputStream = resource.getInputStream()) { this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); @@ -122,7 +123,7 @@ public TemplateFormat getTemplateFormat() { @Override public String render() { validate(this.dynamicModel); - return st.render(); + return this.st.render(); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java index 872d35f4c63..76d5ef01748 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java index dd4424d0731..120cd87aa4f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.prompt; -import org.springframework.ai.chat.messages.Message; +package org.springframework.ai.chat.prompt; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.messages.Message; + public interface PromptTemplateChatActions { List createMessages(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java index 8edcd36dab5..c87507b3c6b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.prompt; -import org.springframework.ai.model.Media; -import org.springframework.ai.chat.messages.Message; +package org.springframework.ai.chat.prompt; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.model.Media; + public interface PromptTemplateMessageActions { Message createMessage(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java index bd81ed4ddf1..81be88b9336 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java index 8ac1aa85e42..18b1629dbec 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; +import java.util.Map; + import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.core.io.Resource; -import java.util.Map; - public class SystemPromptTemplate extends PromptTemplate { public SystemPromptTemplate(String template) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java index fe13fcf7d8a..a174300e9cc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; public enum TemplateFormat { @@ -25,10 +26,6 @@ public enum TemplateFormat { this.value = value; } - public String getValue() { - return value; - } - public static TemplateFormat fromValue(String value) { for (TemplateFormat templateFormat : TemplateFormat.values()) { if (templateFormat.getValue().equals(value)) { @@ -38,4 +35,8 @@ public static TemplateFormat fromValue(String value) { throw new IllegalArgumentException("Invalid TemplateFormat value: " + value); } + public String getValue() { + return this.value; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java index 6f209fa863a..b4e4d0868a6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import org.springframework.core.convert.support.DefaultConversionService; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java index 05077025cb7..7a22dfc5553 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import org.springframework.messaging.converter.MessageConverter; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java index 424d291de7c..00c948372aa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.converter; -import static com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON; -import static com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12; +package org.springframework.ai.converter; import java.lang.reflect.Type; import java.util.Objects; -import com.fasterxml.jackson.databind.json.JsonMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.ai.util.JacksonUtils; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.lang.NonNull; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.util.DefaultIndenter; @@ -37,12 +27,19 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.json.JsonMapper; import com.github.victools.jsonschema.generator.Option; import com.github.victools.jsonschema.generator.SchemaGenerator; import com.github.victools.jsonschema.generator.SchemaGeneratorConfig; import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; import com.github.victools.jsonschema.module.jackson.JacksonModule; import com.github.victools.jsonschema.module.jackson.JacksonOption; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.util.JacksonUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.lang.NonNull; /** * An implementation of {@link StructuredOutputConverter} that transforms the LLM output @@ -62,9 +59,6 @@ public class BeanOutputConverter implements StructuredOutputConverter { private final Logger logger = LoggerFactory.getLogger(BeanOutputConverter.class); - /** Holds the generated JSON schema for the target type. */ - private String jsonSchema; - /** * The target class type reference to which the output will be converted. */ @@ -73,6 +67,9 @@ public class BeanOutputConverter implements StructuredOutputConverter { /** The object mapper used for deserialization and other JSON operations. */ private final ObjectMapper objectMapper; + /** Holds the generated JSON schema for the target type. */ + private String jsonSchema; + /** * Constructor to initialize with the target type's class. * @param clazz The target type's class. @@ -110,21 +107,6 @@ public BeanOutputConverter(ParameterizedTypeReference typeRef, ObjectMapper o this(new CustomizedTypeReference<>(typeRef), objectMapper); } - private static class CustomizedTypeReference extends TypeReference { - - private final Type type; - - CustomizedTypeReference(ParameterizedTypeReference typeRef) { - this.type = typeRef.getType(); - } - - @Override - public Type getType() { - return this.type; - } - - } - /** * Constructor to initialize with the target class type reference, a custom object * mapper, and a line endings normalizer to ensure consistent line endings on any @@ -144,7 +126,9 @@ private BeanOutputConverter(TypeReference typeRef, ObjectMapper objectMapper) */ private void generateSchema() { JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED); - SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(DRAFT_2020_12, PLAIN_JSON) + SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder( + com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, + com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON) .with(jacksonModule) .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT); SchemaGeneratorConfig config = configBuilder.build(); @@ -156,7 +140,7 @@ private void generateSchema() { this.jsonSchema = objectWriter.writeValueAsString(jsonNode); } catch (JsonProcessingException e) { - logger.error("Could not pretty print json schema for jsonNode: " + jsonNode); + this.logger.error("Could not pretty print json schema for jsonNode: " + jsonNode); throw new RuntimeException("Could not pretty print json schema for " + this.typeRef, e); } } @@ -192,7 +176,8 @@ public T convert(@NonNull String text) { return (T) this.objectMapper.readValue(text, this.typeRef); } catch (JsonProcessingException e) { - logger.error("Could not parse the given text to the desired target type:" + text + " into " + this.typeRef); + this.logger + .error("Could not parse the given text to the desired target type:" + text + " into " + this.typeRef); throw new RuntimeException(e); } } @@ -234,4 +219,19 @@ public String getJsonSchema() { return this.jsonSchema; } + private static class CustomizedTypeReference extends TypeReference { + + private final Type type; + + CustomizedTypeReference(ParameterizedTypeReference typeRef) { + this.type = typeRef.getType(); + } + + @Override + public Type getType() { + return this.type; + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java index 9afbc14e32c..eea1e89ced0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java index 65a214f0298..3a16e275888 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java index 682e6fcf7bb..f5100ebfb50 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.nio.charset.StandardCharsets; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md b/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md index 125f9b4f24d..3f9c03d0b5f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md @@ -8,6 +8,8 @@ It may be a correct JSON, but it isn’t a JSON data structure. It is just a string. Also, asking "for JSON" as part of the prompt isn’t 100% accurate. -This intricacy has led to the emergence of a specialized field involving the creation of prompts to yield the intended output, followed by converting the resulting simple string into a usable data structure for application integration. +This intricacy has led to the emergence of a specialized field involving the creation of prompts to yield the intended +output, followed by converting the resulting simple string into a usable data structure for application integration. -Structure output conversion employs meticulously crafted prompts, often necessitating multiple interactions with the model to achieve the desired formatting. +Structure output conversion employs meticulously crafted prompts, often necessitating multiple interactions with the +model to achieve the desired formatting. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java index 4756468a8ad..40a5f54839b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import org.springframework.core.convert.converter.Converter; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java index 8b5ec57c305..a0a2e1d2cc0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java index 7db065225ab..570b3afb47b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.ArrayList; @@ -30,7 +31,7 @@ /** * @author Christian Tzolov */ -public class DefaultContentFormatter implements ContentFormatter { +public final class DefaultContentFormatter implements ContentFormatter { private static final String TEMPLATE_CONTENT_PLACEHOLDER = "{content}"; @@ -74,6 +75,14 @@ public class DefaultContentFormatter implements ContentFormatter { */ private final List excludedEmbedMetadataKeys; + private DefaultContentFormatter(Builder builder) { + this.metadataTemplate = builder.metadataTemplate; + this.metadataSeparator = builder.metadataSeparator; + this.textTemplate = builder.textTemplate; + this.excludedInferenceMetadataKeys = builder.excludedInferenceMetadataKeys; + this.excludedEmbedMetadataKeys = builder.excludedEmbedMetadataKeys; + } + /** * Start building a new configuration. * @return The entry point for creating a new configuration. @@ -90,15 +99,71 @@ public static DefaultContentFormatter defaultConfig() { return builder().build(); } - private DefaultContentFormatter(Builder builder) { - this.metadataTemplate = builder.metadataTemplate; - this.metadataSeparator = builder.metadataSeparator; - this.textTemplate = builder.textTemplate; - this.excludedInferenceMetadataKeys = builder.excludedInferenceMetadataKeys; - this.excludedEmbedMetadataKeys = builder.excludedEmbedMetadataKeys; + @Override + public String format(Document document, MetadataMode metadataMode) { + + var metadata = metadataFilter(document.getMetadata(), metadataMode); + + var metadataText = metadata.entrySet() + .stream() + .map(metadataEntry -> this.metadataTemplate.replace(TEMPLATE_KEY_PLACEHOLDER, metadataEntry.getKey()) + .replace(TEMPLATE_VALUE_PLACEHOLDER, metadataEntry.getValue().toString())) + .collect(Collectors.joining(this.metadataSeparator)); + + return this.textTemplate.replace(TEMPLATE_METADATA_STRING_PLACEHOLDER, metadataText) + .replace(TEMPLATE_CONTENT_PLACEHOLDER, document.getContent()); } - public static class Builder { + /** + * Filters the metadata by the configured MetadataMode. + * @param metadata Document metadata. + * @return Returns the filtered by configured mode metadata. + */ + protected Map metadataFilter(Map metadata, MetadataMode metadataMode) { + + if (metadataMode == MetadataMode.ALL) { + return new HashMap(metadata); + } + if (metadataMode == MetadataMode.NONE) { + return new HashMap(Collections.emptyMap()); + } + + Set usableMetadataKeys = new HashSet<>(metadata.keySet()); + + if (metadataMode == MetadataMode.INFERENCE) { + usableMetadataKeys.removeAll(this.excludedInferenceMetadataKeys); + } + else if (metadataMode == MetadataMode.EMBED) { + usableMetadataKeys.removeAll(this.excludedEmbedMetadataKeys); + } + + return new HashMap(metadata.entrySet() + .stream() + .filter(e -> usableMetadataKeys.contains(e.getKey())) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + } + + public String getMetadataTemplate() { + return this.metadataTemplate; + } + + public String getMetadataSeparator() { + return this.metadataSeparator; + } + + public String getTextTemplate() { + return this.textTemplate; + } + + public List getExcludedInferenceMetadataKeys() { + return Collections.unmodifiableList(this.excludedInferenceMetadataKeys); + } + + public List getExcludedEmbedMetadataKeys() { + return Collections.unmodifiableList(this.excludedEmbedMetadataKeys); + } + + public static final class Builder { private String metadataTemplate = DEFAULT_METADATA_TEMPLATE; @@ -199,68 +264,4 @@ public DefaultContentFormatter build() { } - @Override - public String format(Document document, MetadataMode metadataMode) { - - var metadata = metadataFilter(document.getMetadata(), metadataMode); - - var metadataText = metadata.entrySet() - .stream() - .map(metadataEntry -> this.metadataTemplate.replace(TEMPLATE_KEY_PLACEHOLDER, metadataEntry.getKey()) - .replace(TEMPLATE_VALUE_PLACEHOLDER, metadataEntry.getValue().toString())) - .collect(Collectors.joining(this.metadataSeparator)); - - return this.textTemplate.replace(TEMPLATE_METADATA_STRING_PLACEHOLDER, metadataText) - .replace(TEMPLATE_CONTENT_PLACEHOLDER, document.getContent()); - } - - /** - * Filters the metadata by the configured MetadataMode. - * @param metadata Document metadata. - * @return Returns the filtered by configured mode metadata. - */ - protected Map metadataFilter(Map metadata, MetadataMode metadataMode) { - - if (metadataMode == MetadataMode.ALL) { - return new HashMap(metadata); - } - if (metadataMode == MetadataMode.NONE) { - return new HashMap(Collections.emptyMap()); - } - - Set usableMetadataKeys = new HashSet<>(metadata.keySet()); - - if (metadataMode == MetadataMode.INFERENCE) { - usableMetadataKeys.removeAll(this.excludedInferenceMetadataKeys); - } - else if (metadataMode == MetadataMode.EMBED) { - usableMetadataKeys.removeAll(this.excludedEmbedMetadataKeys); - } - - return new HashMap(metadata.entrySet() - .stream() - .filter(e -> usableMetadataKeys.contains(e.getKey())) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); - } - - public String getMetadataTemplate() { - return this.metadataTemplate; - } - - public String getMetadataSeparator() { - return this.metadataSeparator; - } - - public String getTextTemplate() { - return this.textTemplate; - } - - public List getExcludedInferenceMetadataKeys() { - return Collections.unmodifiableList(this.excludedInferenceMetadataKeys); - } - - public List getExcludedEmbedMetadataKeys() { - return Collections.unmodifiableList(this.excludedEmbedMetadataKeys); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java index 7722a4dfe29..e666857dd46 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.document.id.IdGenerator; import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.model.Media; @@ -26,12 +34,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - /** * A document is a container for the content and metadata of a document. It also contains * the document's unique ID and an optional embedding. @@ -48,12 +50,6 @@ public class Document implements MediaContent { */ private final String id; - /** - * Metadata for the document. It should not be nested and values should be restricted - * to string, int, float, boolean for simple use with Vector Dbs. - */ - private Map metadata; - /** * Document content. */ @@ -61,6 +57,12 @@ public class Document implements MediaContent { private final Collection media; + /** + * Metadata for the document. It should not be nested and values should be restricted + * to string, int, float, boolean for simple use with Vector Dbs. + */ + private Map metadata; + /** * Embedding of the document. Note: ephemeral field. */ @@ -109,72 +111,8 @@ public static Builder builder() { return new Builder(); } - public static class Builder { - - private String id; - - private String content = Document.EMPTY_TEXT; - - private List media = new ArrayList<>(); - - private Map metadata = new HashMap<>(); - - private IdGenerator idGenerator = new RandomIdGenerator(); - - public Builder withIdGenerator(IdGenerator idGenerator) { - Assert.notNull(idGenerator, "idGenerator must not be null"); - this.idGenerator = idGenerator; - return this; - } - - public Builder withId(String id) { - Assert.hasText(id, "id must not be null or empty"); - this.id = id; - return this; - } - - public Builder withContent(String content) { - Assert.notNull(content, "content must not be null"); - this.content = content; - return this; - } - - public Builder withMedia(List media) { - Assert.notNull(media, "media must not be null"); - this.media = media; - return this; - } - - public Builder withMedia(Media media) { - Assert.notNull(media, "media must not be null"); - this.media.add(media); - return this; - } - - public Builder withMetadata(Map metadata) { - Assert.notNull(metadata, "metadata must not be null"); - this.metadata = metadata; - return this; - } - - public Builder withMetadata(String key, Object value) { - Assert.notNull(key, "key must not be null"); - Assert.notNull(value, "value must not be null"); - this.metadata.put(key, value); - return this; - } - - public Document build() { - if (!StringUtils.hasText(this.id)) { - this.id = this.idGenerator.generateId(content, metadata); - } - return new Document(id, content, media, metadata); - } - - } - public String getId() { - return id; + return this.id; } @Override @@ -206,11 +144,24 @@ public String getFormattedContent(ContentFormatter formatter, MetadataMode metad return formatter.format(this, metadataMode); } + @Override + public Map getMetadata() { + return this.metadata; + } + + public float[] getEmbedding() { + return this.embedding; + } + public void setEmbedding(float[] embedding) { Assert.notNull(embedding, "embedding must not be null"); this.embedding = embedding; } + public ContentFormatter getContentFormatter() { + return this.contentFormatter; + } + /** * Replace the document's {@link ContentFormatter}. * @param contentFormatter new formatter to use. @@ -219,63 +170,123 @@ public void setContentFormatter(ContentFormatter contentFormatter) { this.contentFormatter = contentFormatter; } - @Override - public Map getMetadata() { - return this.metadata; - } - - public float[] getEmbedding() { - return this.embedding; - } - - public ContentFormatter getContentFormatter() { - return contentFormatter; - } - @Override public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((id == null) ? 0 : id.hashCode()); - result = prime * result + ((metadata == null) ? 0 : metadata.hashCode()); - result = prime * result + ((content == null) ? 0 : content.hashCode()); + result = prime * result + ((this.id == null) ? 0 : this.id.hashCode()); + result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode()); + result = prime * result + ((this.content == null) ? 0 : this.content.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } Document other = (Document) obj; - if (id == null) { - if (other.id != null) + if (this.id == null) { + if (other.id != null) { return false; + } } - else if (!id.equals(other.id)) + else if (!this.id.equals(other.id)) { return false; - if (metadata == null) { - if (other.metadata != null) + } + if (this.metadata == null) { + if (other.metadata != null) { return false; + } } - else if (!metadata.equals(other.metadata)) + else if (!this.metadata.equals(other.metadata)) { return false; - if (content == null) { - if (other.content != null) + } + if (this.content == null) { + if (other.content != null) { return false; + } } - else if (!content.equals(other.content)) + else if (!this.content.equals(other.content)) { return false; + } return true; } @Override public String toString() { - return "Document{" + "id='" + id + '\'' + ", metadata=" + metadata + ", content='" + content + '\'' + ", media=" - + media + '}'; + return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content + + '\'' + ", media=" + this.media + '}'; + } + + public static class Builder { + + private String id; + + private String content = Document.EMPTY_TEXT; + + private List media = new ArrayList<>(); + + private Map metadata = new HashMap<>(); + + private IdGenerator idGenerator = new RandomIdGenerator(); + + public Builder withIdGenerator(IdGenerator idGenerator) { + Assert.notNull(idGenerator, "idGenerator must not be null"); + this.idGenerator = idGenerator; + return this; + } + + public Builder withId(String id) { + Assert.hasText(id, "id must not be null or empty"); + this.id = id; + return this; + } + + public Builder withContent(String content) { + Assert.notNull(content, "content must not be null"); + this.content = content; + return this; + } + + public Builder withMedia(List media) { + Assert.notNull(media, "media must not be null"); + this.media = media; + return this; + } + + public Builder withMedia(Media media) { + Assert.notNull(media, "media must not be null"); + this.media.add(media); + return this; + } + + public Builder withMetadata(Map metadata) { + Assert.notNull(metadata, "metadata must not be null"); + this.metadata = metadata; + return this; + } + + public Builder withMetadata(String key, Object value) { + Assert.notNull(key, "key must not be null"); + Assert.notNull(value, "value must not be null"); + this.metadata.put(key, value); + return this; + } + + public Document build() { + if (!StringUtils.hasText(this.id)) { + this.id = this.idGenerator.generateId(this.content, this.metadata); + } + return new Document(this.id, this.content, this.media, this.metadata); + } + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java index 75b4fe2b26b..f6179ca4ebb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java index 50d0b4b13a9..618af3fc664 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java index 8c325a7bd0d..6f17faf7c12 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java index 31aeaf905ab..a85fb49e54f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java b/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java index 3d32a2b5dcc..733e1cbfbba 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; public enum MetadataMode { - ALL, EMBED, INFERENCE, NONE; + ALL, EMBED, INFERENCE, NONE -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java index 198c114d52b..f9c43726b77 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java index 1302a6dd3c8..ca561b0355e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.io.ByteArrayOutputStream; @@ -86,7 +87,7 @@ private byte[] serializeToBytes(Object... contents) { MessageDigest getMessageDigest() { try { - return (MessageDigest) messageDigest.clone(); + return (MessageDigest) this.messageDigest.clone(); } catch (CloneNotSupportedException e) { throw new RuntimeException("Unsupported clone for MessageDigest.", e); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java index 0920e9a0424..8c50c2e560c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.util.UUID; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java index 1c6c1c374b0..8165e5d7aff 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.io.IOException; @@ -31,10 +32,10 @@ */ public abstract class AbstractEmbeddingModel implements EmbeddingModel { - protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1); - private static Map KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions(); + protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1); + /** * Return the dimension of the requested embedding generative name. If the generative * name is unknown uses the EmbeddingModel to perform a dummy EmbeddingModel#embed and diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java index 4f73cab0684..e354f1da87c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java index eb4a8354004..6237b98d9d2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import org.springframework.ai.model.Model; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java index 8227b291095..6fc754c1178 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.Arrays; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java index daaa20d0eee..1dabb36d40c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.Objects; @@ -56,7 +57,7 @@ public Embedding(float[] embedding, Integer index, EmbeddingResultMetadata metad */ @Override public float[] getOutput() { - return embedding; + return this.embedding; } /** @@ -75,17 +76,19 @@ public EmbeddingResultMetadata getMetadata() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } Embedding other = (Embedding) o; return Objects.equals(this.embedding, other.embedding) && Objects.equals(this.index, other.index); } @Override public int hashCode() { - return Objects.hash(embedding, index); + return Objects.hash(this.embedding, this.index); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java index 874fadfed81..e4785b867dc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.document.Document; import org.springframework.ai.model.Model; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; - /** * EmbeddingModel is a generic interface for embedding models. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java index 3fac8119034..f7461249f72 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java index ab13dff1aaf..cdb4fb999a7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; /** * @author Thomas Vitale * @since 1.0.0 */ -public class EmbeddingOptionsBuilder { +public final class EmbeddingOptionsBuilder { private final DefaultEmbeddingOptions embeddingOptions = new DefaultEmbeddingOptions(); @@ -31,17 +32,17 @@ public static EmbeddingOptionsBuilder builder() { } public EmbeddingOptionsBuilder withModel(String model) { - embeddingOptions.setModel(model); + this.embeddingOptions.setModel(model); return this; } public EmbeddingOptionsBuilder withDimensions(Integer dimensions) { - embeddingOptions.setDimensions(dimensions); + this.embeddingOptions.setDimensions(dimensions); return this; } public EmbeddingOptions build() { - return embeddingOptions; + return this.embeddingOptions; } private static class DefaultEmbeddingOptions implements EmbeddingOptions { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java index e5512bfe2a1..70429783eca 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java index b8926256741..2ad2afac32f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; @@ -58,13 +59,13 @@ public EmbeddingResponse(List embeddings, EmbeddingResponseMetadata m * @return Get the embedding metadata. */ public EmbeddingResponseMetadata getMetadata() { - return metadata; + return this.metadata; } @Override public Embedding getResult() { - Assert.notEmpty(embeddings, "No embedding data available."); - return embeddings.get(0); + Assert.notEmpty(this.embeddings, "No embedding data available."); + return this.embeddings.get(0); } /** @@ -72,27 +73,29 @@ public Embedding getResult() { */ @Override public List getResults() { - return embeddings; + return this.embeddings; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } EmbeddingResponse that = (EmbeddingResponse) o; - return Objects.equals(embeddings, that.embeddings) && Objects.equals(metadata, that.metadata); + return Objects.equals(this.embeddings, that.embeddings) && Objects.equals(this.metadata, that.metadata); } @Override public int hashCode() { - return Objects.hash(embeddings, metadata); + return Objects.hash(this.embeddings, this.metadata); } @Override public String toString() { - return "EmbeddingResult{" + "data=" + embeddings + ", metadata=" + metadata + '}'; + return "EmbeddingResult{" + "data=" + this.embeddings + ", metadata=" + this.metadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java index 335ac0ae2b8..a9440dc7ae7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; +import java.util.Map; + import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.model.AbstractResponseMetadata; import org.springframework.ai.model.ResponseMetadata; -import java.util.Map; - /** * Common AI provider metadata returned in an embedding response. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java index 9b7df810b39..eb0dfead4d2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import org.springframework.ai.model.ResultMetadata; @@ -27,12 +28,6 @@ public class EmbeddingResultMetadata implements ResultMetadata { public static EmbeddingResultMetadata EMPTY = new EmbeddingResultMetadata(); - public enum ModalityType { - - TEXT, IMAGE, AUDIO, VIDEO; - - } - /** * The {@link MimeType} of the source data used to generate the embedding. */ @@ -75,6 +70,12 @@ public Object getDocumentData() { return this.documentData; } + public enum ModalityType { + + TEXT, IMAGE, AUDIO, VIDEO + + } + public static class ModalityUtils { private static MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java index 2ff2dce0e9a..298278c3581 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.ArrayList; @@ -20,13 +21,13 @@ import java.util.List; import java.util.Map; +import com.knuddels.jtokkit.api.EncodingType; + import org.springframework.ai.document.ContentFormatter; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; import org.springframework.ai.tokenizer.TokenCountEstimator; - -import com.knuddels.jtokkit.api.EncodingType; import org.springframework.util.Assert; /** @@ -144,7 +145,7 @@ public List> batch(List documents) { for (Document document : documentTokens.keySet()) { Integer tokenCount = documentTokens.get(document); - if (currentSize + tokenCount > maxInputTokenCount) { + if (currentSize + tokenCount > this.maxInputTokenCount) { batches.add(currentBatch); currentBatch = new ArrayList<>(); currentSize = 0; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java index 6e4269fe9b4..6949f0e000d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; + import org.springframework.util.StringUtils; /** @@ -27,14 +29,14 @@ */ public class DefaultEmbeddingModelObservationConvention implements EmbeddingModelObservationConvention { + public static final String DEFAULT_NAME = "gen_ai.client.operation"; + private static final KeyValue REQUEST_MODEL_NONE = KeyValue .of(EmbeddingModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, KeyValue.NONE_VALUE); private static final KeyValue RESPONSE_MODEL_NONE = KeyValue .of(EmbeddingModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL, KeyValue.NONE_VALUE); - public static final String DEFAULT_NAME = "gen_ai.client.operation"; - @Override public String getName() { return DEFAULT_NAME; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java index 8d5fb754b83..84a7b8f048c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationHandler; + import org.springframework.ai.model.observation.ModelUsageMetricsGenerator; /** @@ -38,7 +40,8 @@ public EmbeddingModelMeterObservationHandler(MeterRegistry meterRegistry) { public void onStop(EmbeddingModelObservationContext context) { if (context.getResponse() != null && context.getResponse().getMetadata() != null && context.getResponse().getMetadata().getUsage() != null) { - ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context, meterRegistry); + ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context, + this.meterRegistry); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java index 2b6b09c6771..9b46135ae7b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; import org.springframework.ai.embedding.EmbeddingOptions; @@ -44,15 +45,15 @@ public class EmbeddingModelObservationContext extends ModelObservationContext getDataList() { - return dataList; + return this.dataList; } public String getResponseContent() { - return responseContent; + return this.responseContent; } @Override public String toString() { - return "EvaluationRequest{" + "userText='" + userText + '\'' + ", dataList=" + dataList + ", chatResponse=" - + responseContent + '}'; + return "EvaluationRequest{" + "userText='" + this.userText + '\'' + ", dataList=" + this.dataList + + ", chatResponse=" + this.responseContent + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof EvaluationRequest that)) + } + if (!(o instanceof EvaluationRequest that)) { return false; - return Objects.equals(userText, that.userText) && Objects.equals(dataList, that.dataList) - && Objects.equals(responseContent, that.responseContent); + } + return Objects.equals(this.userText, that.userText) && Objects.equals(this.dataList, that.dataList) + && Objects.equals(this.responseContent, that.responseContent); } @Override public int hashCode() { - return Objects.hash(userText, dataList, responseContent); + return Objects.hash(this.userText, this.dataList, this.responseContent); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java index f866cb5e247..ead7fa565ff 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.evaluation; import java.util.Map; @@ -29,40 +45,42 @@ public EvaluationResponse(boolean pass, String feedback, Map met } public boolean isPass() { - return pass; + return this.pass; } public float getScore() { - return score; + return this.score; } public String getFeedback() { - return feedback; + return this.feedback; } public Map getMetadata() { - return metadata; + return this.metadata; } @Override public String toString() { - return "EvaluationResponse{" + "pass=" + pass + ", score=" + score + ", feedback='" + feedback + '\'' - + ", metadata=" + metadata + '}'; + return "EvaluationResponse{" + "pass=" + this.pass + ", score=" + this.score + ", feedback='" + this.feedback + + '\'' + ", metadata=" + this.metadata + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof EvaluationResponse that)) + } + if (!(o instanceof EvaluationResponse that)) { return false; - return pass == that.pass && Float.compare(score, that.score) == 0 && Objects.equals(feedback, that.feedback) - && Objects.equals(metadata, that.metadata); + } + return this.pass == that.pass && Float.compare(this.score, that.score) == 0 + && Objects.equals(this.feedback, that.feedback) && Objects.equals(this.metadata, that.metadata); } @Override public int hashCode() { - return Objects.hash(pass, score, feedback, metadata); + return Objects.hash(this.pass, this.score, this.feedback, this.metadata); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java index b14fe2adb90..9b12205d3dd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java @@ -1,11 +1,27 @@ -package org.springframework.ai.evaluation; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.model.Content; -import org.springframework.util.StringUtils; +package org.springframework.ai.evaluation; import java.util.List; import java.util.stream.Collectors; +import org.springframework.ai.model.Content; +import org.springframework.util.StringUtils; + @FunctionalInterface public interface Evaluator { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java index eb16f66c445..77bd676ed3d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java @@ -1,9 +1,25 @@ -package org.springframework.ai.evaluation; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.chat.client.ChatClient; +package org.springframework.ai.evaluation; import java.util.Collections; +import org.springframework.ai.chat.client.ChatClient; + /** * The FactCheckingEvaluator class implements a method for evaluating the factual accuracy * of Large Language Model (LLM) responses against provided context. @@ -48,8 +64,8 @@ public class FactCheckingEvaluator implements Evaluator { private static final String DEFAULT_EVALUATION_PROMPT_TEXT = """ - Document: \\n {document}\\n - Claim: \\n {claim} + Document: \\n {document}\\n + Claim: \\n {claim} """; private final ChatClient.Builder chatClientBuilder; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java index 5a0ec203a6d..b85591333d8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java @@ -1,21 +1,37 @@ -package org.springframework.ai.evaluation; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.chat.client.ChatClient; +package org.springframework.ai.evaluation; import java.util.Collections; +import org.springframework.ai.chat.client.ChatClient; + public class RelevancyEvaluator implements Evaluator { private static final String DEFAULT_EVALUATION_PROMPT_TEXT = """ - Your task is to evaluate if the response for the query - is in line with the context information provided.\\n - You have two options to answer. Either YES/ NO.\\n - Answer - YES, if the response for the query - is in line with context information otherwise NO.\\n - Query: \\n {query}\\n - Response: \\n {response}\\n - Context: \\n {context}\\n - Answer: " + Your task is to evaluate if the response for the query + is in line with the context information provided.\\n + You have two options to answer. Either YES/ NO.\\n + Answer - YES, if the response for the query + is in line with context information otherwise NO.\\n + Query: \\n {query}\\n + Response: \\n {response}\\n + Context: \\n {context}\\n + Answer: " """; private final ChatClient.Builder chatClientBuilder; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java index 8adc3677e09..bf1f683a16a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import java.util.Objects; @@ -35,7 +36,7 @@ public Image(String url, String b64Json) { } public String getUrl() { - return url; + return this.url; } public void setUrl(String url) { @@ -43,7 +44,7 @@ public void setUrl(String url) { } public String getB64Json() { - return b64Json; + return this.b64Json; } public void setB64Json(String b64Json) { @@ -52,21 +53,23 @@ public void setB64Json(String b64Json) { @Override public String toString() { - return "Image{" + "url='" + url + '\'' + ", b64Json='" + b64Json + '\'' + '}'; + return "Image{" + "url='" + this.url + '\'' + ", b64Json='" + this.b64Json + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Image image)) + } + if (!(o instanceof Image image)) { return false; - return Objects.equals(url, image.url) && Objects.equals(b64Json, image.b64Json); + } + return Objects.equals(this.url, image.url) && Objects.equals(this.b64Json, image.b64Json); } @Override public int hashCode() { - return Objects.hash(url, b64Json); + return Objects.hash(this.url, this.b64Json); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java index 431afd81325..3f9425f5530 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelResult; @@ -34,17 +35,18 @@ public ImageGeneration(Image image, ImageGenerationMetadata imageGenerationMetad @Override public Image getOutput() { - return image; + return this.image; } @Override public ImageGenerationMetadata getMetadata() { - return imageGenerationMetadata; + return this.imageGenerationMetadata; } @Override public String toString() { - return "ImageGeneration{" + "imageGenerationMetadata=" + imageGenerationMetadata + ", image=" + image + '}'; + return "ImageGeneration{" + "imageGenerationMetadata=" + this.imageGenerationMetadata + ", image=" + this.image + + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java index 164f781d172..7a513390c8a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.ResultMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java index 2b298bb0715..72825b1a4ee 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import java.util.Objects; @@ -33,30 +34,32 @@ public ImageMessage(String text, Float weight) { } public String getText() { - return text; + return this.text; } public Float getWeight() { - return weight; + return this.weight; } @Override public String toString() { - return "ImageMessage{" + "text='" + text + '\'' + ", weight=" + weight + '}'; + return "ImageMessage{" + "text='" + this.text + '\'' + ", weight=" + this.weight + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ImageMessage that)) + } + if (!(o instanceof ImageMessage that)) { return false; - return Objects.equals(text, that.text) && Objects.equals(weight, that.weight); + } + return Objects.equals(this.text, that.text) && Objects.equals(this.weight, that.weight); } @Override public int hashCode() { - return Objects.hash(text, weight); + return Objects.hash(this.text, this.weight); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java index 493da50bf15..466931a68bd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.Model; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java index af896202983..435f6fc62df 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index 7917f3df550..30f1f010527 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,54 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; -public class ImageOptionsBuilder { +public final class ImageOptionsBuilder { + + private final DefaultImageModelOptions options = new DefaultImageModelOptions(); + + private ImageOptionsBuilder() { + + } + + public static ImageOptionsBuilder builder() { + return new ImageOptionsBuilder(); + } + + public ImageOptionsBuilder withN(Integer n) { + this.options.setN(n); + return this; + } + + public ImageOptionsBuilder withModel(String model) { + this.options.setModel(model); + return this; + } + + public ImageOptionsBuilder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public ImageOptionsBuilder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public ImageOptionsBuilder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public ImageOptionsBuilder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + public ImageOptions build() { + return this.options; + } private static class DefaultImageModelOptions implements ImageOptions { @@ -33,7 +78,7 @@ private static class DefaultImageModelOptions implements ImageOptions { @Override public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -42,7 +87,7 @@ public void setN(Integer n) { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -51,7 +96,7 @@ public void setModel(String model) { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -60,7 +105,7 @@ public void setResponseFormat(String responseFormat) { @Override public Integer getWidth() { - return width; + return this.width; } public void setWidth(Integer width) { @@ -69,7 +114,7 @@ public void setWidth(Integer width) { @Override public Integer getHeight() { - return height; + return this.height; } public void setHeight(Integer height) { @@ -78,7 +123,7 @@ public void setHeight(Integer height) { @Override public String getStyle() { - return style; + return this.style; } public void setStyle(String style) { @@ -87,48 +132,4 @@ public void setStyle(String style) { } - private final DefaultImageModelOptions options = new DefaultImageModelOptions(); - - private ImageOptionsBuilder() { - - } - - public static ImageOptionsBuilder builder() { - return new ImageOptionsBuilder(); - } - - public ImageOptionsBuilder withN(Integer n) { - options.setN(n); - return this; - } - - public ImageOptionsBuilder withModel(String model) { - options.setModel(model); - return this; - } - - public ImageOptionsBuilder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public ImageOptionsBuilder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public ImageOptionsBuilder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public ImageOptionsBuilder withStyle(String style) { - options.setStyle(style); - return this; - } - - public ImageOptions build() { - return options; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java index 59ac64c818c..a212c2cf4f5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.image; -import org.springframework.ai.model.ModelRequest; +package org.springframework.ai.image; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelRequest; + public class ImagePrompt implements ModelRequest> { private final List messages; @@ -50,31 +51,34 @@ public ImagePrompt(String instructions) { @Override public List getInstructions() { - return messages; + return this.messages; } @Override public ImageOptions getOptions() { - return imageModelOptions; + return this.imageModelOptions; } @Override public String toString() { - return "NewImagePrompt{" + "messages=" + messages + ", imageModelOptions=" + imageModelOptions + '}'; + return "NewImagePrompt{" + "messages=" + this.messages + ", imageModelOptions=" + this.imageModelOptions + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ImagePrompt that)) + } + if (!(o instanceof ImagePrompt that)) { return false; - return Objects.equals(messages, that.messages) && Objects.equals(imageModelOptions, that.imageModelOptions); + } + return Objects.equals(this.messages, that.messages) + && Objects.equals(this.imageModelOptions, that.imageModelOptions); } @Override public int hashCode() { - return Objects.hash(messages, imageModelOptions); + return Objects.hash(this.messages, this.imageModelOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java index b6d6c87b883..c4605d81890 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import java.util.List; @@ -67,7 +68,7 @@ public ImageResponse(List generations, ImageResponseMetadata im */ @Override public List getResults() { - return imageGenerations; + return this.imageGenerations; } /** @@ -78,7 +79,7 @@ public ImageGeneration getResult() { if (CollectionUtils.isEmpty(this.imageGenerations)) { return null; } - return imageGenerations.get(0); + return this.imageGenerations.get(0); } /** @@ -87,28 +88,30 @@ public ImageGeneration getResult() { */ @Override public ImageResponseMetadata getMetadata() { - return imageResponseMetadata; + return this.imageResponseMetadata; } @Override public String toString() { - return "ImageResponse [" + "imageResponseMetadata=" + imageResponseMetadata + ", imageGenerations=" - + imageGenerations + "]"; + return "ImageResponse [" + "imageResponseMetadata=" + this.imageResponseMetadata + ", imageGenerations=" + + this.imageGenerations + "]"; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ImageResponse that)) + } + if (!(o instanceof ImageResponse that)) { return false; - return Objects.equals(imageResponseMetadata, that.imageResponseMetadata) - && Objects.equals(imageGenerations, that.imageGenerations); + } + return Objects.equals(this.imageResponseMetadata, that.imageResponseMetadata) + && Objects.equals(this.imageGenerations, that.imageGenerations); } @Override public int hashCode() { - return Objects.hash(imageResponseMetadata, imageGenerations); + return Objects.hash(this.imageResponseMetadata, this.imageGenerations); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java index 5d694e54817..816c92b2809 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.MutableResponseMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java index 2413e51e41c..35cb6f51ba3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; + import org.springframework.util.StringUtils; /** @@ -27,11 +29,11 @@ */ public class DefaultImageModelObservationConvention implements ImageModelObservationConvention { + public static final String DEFAULT_NAME = "gen_ai.client.operation"; + private static final KeyValue REQUEST_MODEL_NONE = KeyValue .of(ImageModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, KeyValue.NONE_VALUE); - public static final String DEFAULT_NAME = "gen_ai.client.operation"; - @Override public String getName() { return DEFAULT_NAME; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java index 52846a41dba..34ba0f77055 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import org.springframework.ai.image.ImageOptions; @@ -40,19 +41,19 @@ public class ImageModelObservationContext extends ModelObservationContext doubleToFloat(final List doubles) { return doubles.stream().map(f -> f.floatValue()).toList(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java b/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java index 5391f1f98fd..fe5cd8212fa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; +import java.io.IOException; +import java.net.URL; + import org.springframework.core.io.Resource; import org.springframework.util.Assert; import org.springframework.util.MimeType; -import java.io.IOException; -import java.net.URL; - /** * The Media class represents the data and metadata of a media attachment in a message. It * consists of a MIME type and the raw data. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java b/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java index 4b436e82a66..933ded36b67 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; import java.util.Collection; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java b/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java index 1671a3dc49f..391786543ae 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java index 0335f341eba..71382538b3d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java index 124818e80b4..10f54e02c2a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 8a6cc7d0da3..e049fb17f00 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import java.beans.PropertyDescriptor; @@ -26,13 +27,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import org.springframework.ai.util.JacksonUtils; -import org.springframework.beans.BeanWrapper; -import org.springframework.beans.BeanWrapperImpl; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.ObjectUtils; - import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -53,6 +47,13 @@ import com.github.victools.jsonschema.module.jackson.JacksonOption; import com.github.victools.jsonschema.module.swagger2.Swagger2Module; +import org.springframework.ai.util.JacksonUtils; +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.BeanWrapperImpl; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; + /** * Utility class for manipulating {@link ModelOptions} objects. * @@ -74,6 +75,10 @@ public abstract class ModelOptionsUtils { private static final AtomicReference SCHEMA_GENERATOR_CACHE = new AtomicReference<>(); + private static TypeReference> MAP_TYPE_REF = new TypeReference>() { + + }; + /** * Converts the given JSON string to a Map of String and Object. * @param json the JSON string to convert to a Map. @@ -88,9 +93,6 @@ public static Map jsonToMap(String json) { } } - private static TypeReference> MAP_TYPE_REF = new TypeReference>() { - }; - /** * Converts the given JSON string to an Object of the given type. * @param the type of the object to return. @@ -193,6 +195,7 @@ public static Map objectToMap(Object source) { try { String json = OBJECT_MAPPER.writeValueAsString(source); return OBJECT_MAPPER.readValue(json, new TypeReference>() { + }) .entrySet() .stream() @@ -356,7 +359,7 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI - // version of it). + // version of it). toUpperCaseTypeValues(node); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java index 94c2e8aefc5..7b86a850753 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** @@ -40,4 +41,4 @@ public interface ModelRequest { */ ModelOptions getOptions(); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java index f4a9bf83a05..5df8b8d2a82 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java index 6ee17815c74..f28a9dfc1eb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java index ac0c9254e75..106a90e6867 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java @@ -1,7 +1,20 @@ -package org.springframework.ai.model; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import io.micrometer.common.lang.NonNull; -import io.micrometer.common.lang.Nullable; +package org.springframework.ai.model; import java.util.Collections; import java.util.Map; @@ -9,6 +22,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + public class MutableResponseMetadata implements ResponseMetadata { private final Map map = new ConcurrentHashMap<>(); @@ -120,7 +136,7 @@ public void clear() { } public Map getRawMap() { - return map; + return this.map; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java index 24e544d4f29..7b63e91a481 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.model; -import io.micrometer.common.lang.NonNull; -import io.micrometer.common.lang.Nullable; +package org.springframework.ai.model; import java.util.Map; import java.util.Set; import java.util.function.Supplier; +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + /** * Interface representing metadata associated with an AI model's response. * @@ -80,7 +81,7 @@ default T getOrDefault(String key, Supplier defaultObjectSupplier) { Set> entrySet(); - public Set keySet(); + Set keySet(); /** * Returns {@code true} if this map contains no key-value mappings. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java index 05d2aaca403..85f538d3bae 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java index 2c1de77a9b3..4b11f4fbb21 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import reactor.core.publisher.Flux; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index cd5d43be105..8a2c84aca52 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Function; -import org.springframework.ai.chat.model.ToolContext; -import org.springframework.util.Assert; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.util.Assert; + /** * Abstract implementation of the {@link FunctionCallback} for interacting with the * Model's function calling protocol and a {@link Function} wrapping the interaction with @@ -102,7 +103,7 @@ public String getInputTypeSchema() { @Override public String call(String functionInput, ToolContext toolContext) { - I request = fromJson(functionInput, inputType); + I request = fromJson(functionInput, this.inputType); O response = this.apply(request, toolContext); return this.responseConverter.apply(response); } @@ -110,7 +111,7 @@ public String call(String functionInput, ToolContext toolContext) { @Override public String call(String functionArguments) { // Convert the tool calls JSON arguments into a Java function request object. - I request = fromJson(functionArguments, inputType); + I request = fromJson(functionArguments, this.inputType); // extend conversation with function response. return this.andThen(this.responseConverter).apply(request, null); } @@ -126,15 +127,17 @@ private T fromJson(String json, Class targetClass) { @Override public int hashCode() { - return Objects.hash(name, description, inputType); + return Objects.hash(this.name, this.description, this.inputType); } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null || getClass() != obj.getClass()) + } + if (obj == null || getClass() != obj.getClass()) { return false; + } AbstractFunctionCallback other = (AbstractFunctionCallback) obj; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java index dcad8414022..0e3946c7241 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import org.springframework.ai.chat.model.ToolContext; @@ -28,18 +29,18 @@ public interface FunctionCallback { /** * @return Returns the Function name. Unique within the model. */ - public String getName(); + String getName(); /** * @return Returns the function description. This description is used by the model do * decide if the function should be called or not. */ - public String getDescription(); + String getDescription(); /** * @return Returns the JSON schema of the function input type. */ - public String getInputTypeSchema(); + String getInputTypeSchema(); /** * Called when a model detects and triggers a function call. The model is responsible @@ -49,7 +50,7 @@ public interface FunctionCallback { * model. * @return String containing the function call response. */ - public String call(String functionInput); + String call(String functionInput); /** * Called when a model detects and triggers a function call. The model is responsible @@ -72,4 +73,4 @@ default String call(String functionInput, ToolContext tooContext) { return call(functionInput); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index ef06b80a0e3..ecbb9a4c1c8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.lang.reflect.Type; import java.util.function.BiFunction; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonClassDescription; + import org.springframework.ai.chat.model.ToolContext; import org.springframework.beans.BeansException; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; @@ -31,8 +34,6 @@ import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import com.fasterxml.jackson.annotation.JsonClassDescription; - /** * A Spring {@link ApplicationContextAware} implementation that provides a way to retrieve * a {@link Function} from the Spring context and wrap it into a {@link FunctionCallback}. @@ -53,12 +54,6 @@ public class FunctionCallbackContext implements ApplicationContextAware { private GenericApplicationContext applicationContext; - public enum SchemaType { - - JSON_SCHEMA, OPEN_API_SCHEMA - - } - private SchemaType schemaType = SchemaType.JSON_SCHEMA; public void setSchemaType(SchemaType schemaType) { @@ -94,7 +89,8 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable if (!StringUtils.hasText(functionDescription)) { // Look for a Description annotation on the bean - Description descriptionAnnotation = applicationContext.findAnnotationOnBean(beanName, Description.class); + Description descriptionAnnotation = this.applicationContext.findAnnotationOnBean(beanName, + Description.class); if (descriptionAnnotation != null) { functionDescription = descriptionAnnotation.value(); @@ -139,4 +135,10 @@ else if (bean instanceof BiFunction biFunction) { } } + public enum SchemaType { + + JSON_SCHEMA, OPEN_API_SCHEMA + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index 960292cf8e3..fe9fa0a1533 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,33 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.function.BiFunction; import java.util.function.Function; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.json.JsonMapper; + import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.util.JacksonUtils; import org.springframework.util.Assert; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializationFeature; -import com.fasterxml.jackson.databind.json.JsonMapper; - /** * Note that the underlying function is responsible for converting the output into format * that can be consumed by the Model. The default implementation converts the output into * String before sending it to the Model. Provide a custom function responseConverter * implementation to override this. - * + * * @author Christian Tzolov * @author Sebastien Deleuze * */ -public class FunctionCallbackWrapper extends AbstractFunctionCallback { +public final class FunctionCallbackWrapper extends AbstractFunctionCallback { private final BiFunction biFunction; @@ -50,11 +51,6 @@ private FunctionCallbackWrapper(String name, String description, String inputTyp this.biFunction = function; } - @Override - public O apply(I input, ToolContext context) { - return this.biFunction.apply(input, context); - } - public static Builder builder(BiFunction biFunction) { return new Builder<>(biFunction); } @@ -63,19 +59,31 @@ public static Builder builder(Function function) { return new Builder<>(function); } + @Override + public O apply(I input, ToolContext context) { + return this.biFunction.apply(input, context); + } + public static class Builder { + private final BiFunction biFunction; + + private final Function function; + private String name; private String description; private Class inputType; - private final BiFunction biFunction; + private SchemaType schemaType = SchemaType.JSON_SCHEMA; - private final Function function; + // By default the response is converted to a JSON string. + private Function responseConverter = ModelOptionsUtils::toJsonString; - private SchemaType schemaType = SchemaType.JSON_SCHEMA; + private String inputTypeSchema; + + private ObjectMapper objectMapper; public Builder(BiFunction biFunction) { Assert.notNull(biFunction, "Function must not be null"); @@ -89,12 +97,16 @@ public Builder(Function function) { this.function = function; } - // By default the response is converted to a JSON string. - private Function responseConverter = ModelOptionsUtils::toJsonString; - - private String inputTypeSchema; + @SuppressWarnings("unchecked") + private static Class resolveInputType(BiFunction biFunction) { + return (Class) TypeResolverHelper + .getBiFunctionInputClass((Class>) biFunction.getClass()); + } - private ObjectMapper objectMapper; + @SuppressWarnings("unchecked") + private static Class resolveInputType(Function function) { + return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + } public Builder withName(String name) { Assert.hasText(name, "Name must not be empty"); @@ -173,17 +185,6 @@ public FunctionCallbackWrapper build() { this.responseConverter, this.objectMapper, finalBiFunction); } - @SuppressWarnings("unchecked") - private static Class resolveInputType(BiFunction biFunction) { - return (Class) TypeResolverHelper - .getBiFunctionInputClass((Class>) biFunction.getClass()); - } - - @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); - } - } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index f6189799303..722d7f24f6a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.List; @@ -26,6 +27,14 @@ */ public interface FunctionCallingOptions extends ChatOptions { + /** + * @return Returns FunctionCallingOptionsBuilder to create a new instance of + * FunctionCallingOptions. + */ + static FunctionCallingOptionsBuilder builder() { + return new FunctionCallingOptionsBuilder(); + } + /** * Function Callbacks to be registered with the ChatModel. For Prompt Options the * functionCallbacks are automatically enabled for the duration of the prompt @@ -67,16 +76,8 @@ default void setProxyToolCalls(Boolean proxyToolCalls) { } } - /** - * @return Returns FunctionCallingOptionsBuilder to create a new instance of - * FunctionCallingOptions. - */ - public static FunctionCallingOptionsBuilder builder() { - return new FunctionCallingOptionsBuilder(); - } - Map getToolContext(); void setToolContext(Map tooContext); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index ce84c8b048c..b5304270eca 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.ArrayList; @@ -185,7 +186,7 @@ public void setFunctions(Set functions) { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -194,7 +195,7 @@ public void setModel(String model) { @Override public Double getFrequencyPenalty() { - return frequencyPenalty; + return this.frequencyPenalty; } public void setFrequencyPenalty(Double frequencyPenalty) { @@ -203,7 +204,7 @@ public void setFrequencyPenalty(Double frequencyPenalty) { @Override public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } public void setMaxTokens(Integer maxTokens) { @@ -212,7 +213,7 @@ public void setMaxTokens(Integer maxTokens) { @Override public Double getPresencePenalty() { - return presencePenalty; + return this.presencePenalty; } public void setPresencePenalty(Double presencePenalty) { @@ -221,7 +222,7 @@ public void setPresencePenalty(Double presencePenalty) { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -230,7 +231,7 @@ public void setStopSequences(List stopSequences) { @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -239,7 +240,7 @@ public void setTemperature(Double temperature) { @Override public Integer getTopK() { - return topK; + return this.topK; } public void setTopK(Integer topK) { @@ -248,7 +249,7 @@ public void setTopK(Integer topK) { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -257,7 +258,7 @@ public void setTopP(Double topP) { @Override public Boolean getProxyToolCalls() { - return proxyToolCalls; + return this.proxyToolCalls; } public void setProxyToolCalls(Boolean proxyToolCalls) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java index f3f23868d5c..4df569657b4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model.function; import java.util.ArrayList; @@ -7,6 +23,8 @@ import java.util.Set; import java.util.function.Function; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -19,8 +37,6 @@ import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; - /** * Helper class that reuses the {@link AbstractToolCallSupport} to implement the function * call handling logic on the client side. Used when the withProxyToolCalls(true) option @@ -28,36 +44,6 @@ */ public class ToolCallHelper extends AbstractToolCallSupport { - /** - * Helper used to provide only the function definition, without the actual function - * call implementation. - */ - public static record FunctionDefinition(String name, String description, - String inputTypeSchema) implements FunctionCallback { - - @Override - public String getName() { - return this.name(); - } - - @Override - public String getDescription() { - return this.description(); - } - - @Override - public String getInputTypeSchema() { - return this.inputTypeSchema(); - } - - @Override - public String call(String functionInput) { - throw new UnsupportedOperationException( - "FunctionDefinition provides only metadata. It doesn't implement the call method."); - } - - } - public ToolCallHelper() { this(null, PortableFunctionCallingOptions.builder().build(), List.of()); } @@ -163,4 +149,34 @@ public ChatResponse processCall(ChatModel chatModel, Prompt prompt, Set return processCall(chatModel, prompt2, finishReasons, customFunction); } -} \ No newline at end of file + /** + * Helper used to provide only the function definition, without the actual function + * call implementation. + */ + public static record FunctionDefinition(String name, String description, + String inputTypeSchema) implements FunctionCallback { + + @Override + public String getName() { + return this.name(); + } + + @Override + public String getDescription() { + return this.description(); + } + + @Override + public String getInputTypeSchema() { + return this.inputTypeSchema(); + } + + @Override + public String call(String functionInput) { + throw new UnsupportedOperationException( + "FunctionDefinition provides only metadata. It doesn't implement the call method."); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index ae6176b78f5..8ff8584c4bf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.lang.reflect.GenericArrayType; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java index 21b92e75b33..ff9e0a738e1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java @@ -1,25 +1,24 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model.observation; import java.util.List; import java.util.function.Consumer; -import org.springframework.util.Assert; - import io.micrometer.observation.Observation; import io.micrometer.observation.Observation.Context; import io.micrometer.observation.ObservationHandler; @@ -28,6 +27,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + /** * @author Christian Tzolov * @since 1.0.0 diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java index 0c0ac67671d..931fdf976ac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.observation; import io.micrometer.observation.Observation; + import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.lang.Nullable; import org.springframework.util.Assert; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java index dfd6e6c84eb..4a5eb8eaf71 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.observation; +import java.util.ArrayList; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.observation.Observation; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; import org.springframework.ai.observation.conventions.AiObservationMetricNames; import org.springframework.ai.observation.conventions.AiTokenType; -import java.util.ArrayList; -import java.util.List; - /** * Generate metrics about the model usage in the context of an AI operation. * @@ -38,6 +40,9 @@ public final class ModelUsageMetricsGenerator { private static final String DESCRIPTION = "Measures number of input and output tokens used"; + private ModelUsageMetricsGenerator() { + } + public static void generate(Usage usage, Observation.Context context, MeterRegistry meterRegistry) { if (usage.getPromptTokens() != null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java index 1d581773688..867e8507c76 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.model.observation; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java index 57c2b34db7f..207af410e2f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Provides a set of interfaces and classes for a generic API designed to interact with * various AI models. This package includes interfaces for handling AI model calls, @@ -23,4 +24,5 @@ * ensuring a broad applicability across diverse AI scenarios. * */ -package org.springframework.ai.model; \ No newline at end of file + +package org.springframework.ai.model; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java index 3a170028e5f..1d0be3c39ef 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Objects; @@ -10,29 +26,29 @@ * @author Ahmed Yousri * @since 1.0.0 */ -public class Categories { +public final class Categories { - private boolean sexual; + private final boolean sexual; - private boolean hate; + private final boolean hate; - private boolean harassment; + private final boolean harassment; - private boolean selfHarm; + private final boolean selfHarm; - private boolean sexualMinors; + private final boolean sexualMinors; - private boolean hateThreatening; + private final boolean hateThreatening; - private boolean violenceGraphic; + private final boolean violenceGraphic; - private boolean selfHarmIntent; + private final boolean selfHarmIntent; - private boolean selfHarmInstructions; + private final boolean selfHarmInstructions; - private boolean harassmentThreatening; + private final boolean harassmentThreatening; - private boolean violence; + private final boolean violence; private Categories(Builder builder) { this.sexual = builder.sexual; @@ -48,52 +64,84 @@ private Categories(Builder builder) { this.violence = builder.violence; } + public static Builder builder() { + return new Builder(); + } + public boolean isSexual() { - return sexual; + return this.sexual; } public boolean isHate() { - return hate; + return this.hate; } public boolean isHarassment() { - return harassment; + return this.harassment; } public boolean isSelfHarm() { - return selfHarm; + return this.selfHarm; } public boolean isSexualMinors() { - return sexualMinors; + return this.sexualMinors; } public boolean isHateThreatening() { - return hateThreatening; + return this.hateThreatening; } public boolean isViolenceGraphic() { - return violenceGraphic; + return this.violenceGraphic; } public boolean isSelfHarmIntent() { - return selfHarmIntent; + return this.selfHarmIntent; } public boolean isSelfHarmInstructions() { - return selfHarmInstructions; + return this.selfHarmInstructions; } public boolean isHarassmentThreatening() { - return harassmentThreatening; + return this.harassmentThreatening; } public boolean isViolence() { - return violence; + return this.violence; } - public static Builder builder() { - return new Builder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Categories)) { + return false; + } + Categories that = (Categories) o; + return this.sexual == that.sexual && this.hate == that.hate && this.harassment == that.harassment + && this.selfHarm == that.selfHarm && this.sexualMinors == that.sexualMinors + && this.hateThreatening == that.hateThreatening && this.violenceGraphic == that.violenceGraphic + && this.selfHarmIntent == that.selfHarmIntent && this.selfHarmInstructions == that.selfHarmInstructions + && this.harassmentThreatening == that.harassmentThreatening && this.violence == that.violence; + } + + @Override + public int hashCode() { + return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors, + this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions, + this.harassmentThreatening, this.violence); + } + + @Override + public String toString() { + return "Categories{" + "sexual=" + this.sexual + ", hate=" + this.hate + ", harassment=" + this.harassment + + ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening=" + + this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent=" + + this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions + + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}'; } public static class Builder { @@ -181,33 +229,4 @@ public Categories build() { } - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof Categories)) - return false; - Categories that = (Categories) o; - return sexual == that.sexual && hate == that.hate && harassment == that.harassment && selfHarm == that.selfHarm - && sexualMinors == that.sexualMinors && hateThreatening == that.hateThreatening - && violenceGraphic == that.violenceGraphic && selfHarmIntent == that.selfHarmIntent - && selfHarmInstructions == that.selfHarmInstructions - && harassmentThreatening == that.harassmentThreatening && violence == that.violence; - } - - @Override - public int hashCode() { - return Objects.hash(sexual, hate, harassment, selfHarm, sexualMinors, hateThreatening, violenceGraphic, - selfHarmIntent, selfHarmInstructions, harassmentThreatening, violence); - } - - @Override - public String toString() { - return "Categories{" + "sexual=" + sexual + ", hate=" + hate + ", harassment=" + harassment + ", selfHarm=" - + selfHarm + ", sexualMinors=" + sexualMinors + ", hateThreatening=" + hateThreatening - + ", violenceGraphic=" + violenceGraphic + ", selfHarmIntent=" + selfHarmIntent - + ", selfHarmInstructions=" + selfHarmInstructions + ", harassmentThreatening=" + harassmentThreatening - + ", violence=" + violence + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java index 8429b783472..c96dc4e2b7e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Objects; @@ -10,29 +26,29 @@ * @author Ahmed Yousri * @since 1.0.0 */ -public class CategoryScores { +public final class CategoryScores { - private double sexual; + private final double sexual; - private double hate; + private final double hate; - private double harassment; + private final double harassment; - private double selfHarm; + private final double selfHarm; - private double sexualMinors; + private final double sexualMinors; - private double hateThreatening; + private final double hateThreatening; - private double violenceGraphic; + private final double violenceGraphic; - private double selfHarmIntent; + private final double selfHarmIntent; - private double selfHarmInstructions; + private final double selfHarmInstructions; - private double harassmentThreatening; + private final double harassmentThreatening; - private double violence; + private final double violence; private CategoryScores(Builder builder) { this.sexual = builder.sexual; @@ -48,52 +64,89 @@ private CategoryScores(Builder builder) { this.violence = builder.violence; } + public static Builder builder() { + return new Builder(); + } + public double getSexual() { - return sexual; + return this.sexual; } public double getHate() { - return hate; + return this.hate; } public double getHarassment() { - return harassment; + return this.harassment; } public double getSelfHarm() { - return selfHarm; + return this.selfHarm; } public double getSexualMinors() { - return sexualMinors; + return this.sexualMinors; } public double getHateThreatening() { - return hateThreatening; + return this.hateThreatening; } public double getViolenceGraphic() { - return violenceGraphic; + return this.violenceGraphic; } public double getSelfHarmIntent() { - return selfHarmIntent; + return this.selfHarmIntent; } public double getSelfHarmInstructions() { - return selfHarmInstructions; + return this.selfHarmInstructions; } public double getHarassmentThreatening() { - return harassmentThreatening; + return this.harassmentThreatening; } public double getViolence() { - return violence; + return this.violence; } - public static Builder builder() { - return new Builder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CategoryScores)) { + return false; + } + CategoryScores that = (CategoryScores) o; + return Double.compare(that.sexual, this.sexual) == 0 && Double.compare(that.hate, this.hate) == 0 + && Double.compare(that.harassment, this.harassment) == 0 + && Double.compare(that.selfHarm, this.selfHarm) == 0 + && Double.compare(that.sexualMinors, this.sexualMinors) == 0 + && Double.compare(that.hateThreatening, this.hateThreatening) == 0 + && Double.compare(that.violenceGraphic, this.violenceGraphic) == 0 + && Double.compare(that.selfHarmIntent, this.selfHarmIntent) == 0 + && Double.compare(that.selfHarmInstructions, this.selfHarmInstructions) == 0 + && Double.compare(that.harassmentThreatening, this.harassmentThreatening) == 0 + && Double.compare(that.violence, this.violence) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors, + this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions, + this.harassmentThreatening, this.violence); + } + + @Override + public String toString() { + return "CategoryScores{" + "sexual=" + this.sexual + ", hate=" + this.hate + ", harassment=" + this.harassment + + ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening=" + + this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent=" + + this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions + + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}'; } public static class Builder { @@ -181,37 +234,4 @@ public CategoryScores build() { } - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof CategoryScores)) - return false; - CategoryScores that = (CategoryScores) o; - return Double.compare(that.sexual, sexual) == 0 && Double.compare(that.hate, hate) == 0 - && Double.compare(that.harassment, harassment) == 0 && Double.compare(that.selfHarm, selfHarm) == 0 - && Double.compare(that.sexualMinors, sexualMinors) == 0 - && Double.compare(that.hateThreatening, hateThreatening) == 0 - && Double.compare(that.violenceGraphic, violenceGraphic) == 0 - && Double.compare(that.selfHarmIntent, selfHarmIntent) == 0 - && Double.compare(that.selfHarmInstructions, selfHarmInstructions) == 0 - && Double.compare(that.harassmentThreatening, harassmentThreatening) == 0 - && Double.compare(that.violence, violence) == 0; - } - - @Override - public int hashCode() { - return Objects.hash(sexual, hate, harassment, selfHarm, sexualMinors, hateThreatening, violenceGraphic, - selfHarmIntent, selfHarmInstructions, harassmentThreatening, violence); - } - - @Override - public String toString() { - return "CategoryScores{" + "sexual=" + sexual + ", hate=" + hate + ", harassment=" + harassment + ", selfHarm=" - + selfHarm + ", sexualMinors=" + sexualMinors + ", hateThreatening=" + hateThreatening - + ", violenceGraphic=" + violenceGraphic + ", selfHarmIntent=" + selfHarmIntent - + ", selfHarmInstructions=" + selfHarmInstructions + ", harassmentThreatening=" + harassmentThreatening - + ", violence=" + violence + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java index 98a4cf5fd43..e73ebb2325c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,18 +52,18 @@ public Generation withGenerationMetadata(@Nullable ModerationGenerationMetadata @Override public Moderation getOutput() { - return moderation; + return this.moderation; } @Override public ModerationGenerationMetadata getMetadata() { - return moderationGenerationMetadata; + return this.moderationGenerationMetadata; } @Override public String toString() { - return "Generation{" + "moderationGenerationMetadata=" + moderationGenerationMetadata + ", moderation=" - + moderation + '}'; + return "Generation{" + "moderationGenerationMetadata=" + this.moderationGenerationMetadata + ", moderation=" + + this.moderation + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java index a98b94c727b..7fe43a94882 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Arrays; @@ -12,7 +28,7 @@ * @author Ahmed Yousri * @since 1.0.0 */ -public class Moderation { +public final class Moderation { private final String id; @@ -26,42 +42,44 @@ private Moderation(Builder builder) { this.results = builder.moderationResultList; } + public static Builder builder() { + return new Builder(); + } + public String getId() { - return id; + return this.id; } public String getModel() { - return model; + return this.model; } public List getResults() { - return results; + return this.results; } @Override public String toString() { - return "Moderation{" + "id='" + id + '\'' + ", model='" + model + '\'' + ", results=" - + Arrays.toString(results.toArray()) + '}'; + return "Moderation{" + "id='" + this.id + '\'' + ", model='" + this.model + '\'' + ", results=" + + Arrays.toString(this.results.toArray()) + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Moderation)) + } + if (!(o instanceof Moderation)) { return false; + } Moderation that = (Moderation) o; - return Objects.equals(id, that.id) && Objects.equals(model, that.model) - && Objects.equals(results, that.results); + return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) + && Objects.equals(this.results, that.results); } @Override public int hashCode() { - return Objects.hash(id, model, results); - } - - public static Builder builder() { - return new Builder(); + return Objects.hash(this.id, this.model, this.results); } public static class Builder { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java index f186ec54d71..5cb66e4ad21 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java index 455dd695ca9..335f9f5a16e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ public ModerationMessage(String text) { } public String getText() { - return text; + return this.text; } public void setText(String text) { @@ -44,22 +44,24 @@ public void setText(String text) { @Override public String toString() { - return "ModerationMessage{" + "text='" + text + '\'' + '}'; + return "ModerationMessage{" + "text='" + this.text + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ModerationMessage)) + } + if (!(o instanceof ModerationMessage)) { return false; + } ModerationMessage that = (ModerationMessage) o; - return Objects.equals(text, that.text); + return Objects.equals(this.text, that.text); } @Override public int hashCode() { - return Objects.hash(text); + return Objects.hash(this.text); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java index d7d47bf6707..188fce42a7a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java index 57ac68f4353..238989f07a0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java index edacf2cbae8..b476d33c231 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,22 +25,7 @@ * @author Ahmed Yousri * @since 1.0.0 */ -public class ModerationOptionsBuilder { - - private class ModerationModelOptionsImpl implements ModerationOptions { - - private String model; - - public void setModel(String model) { - this.model = model; - } - - @Override - public String getModel() { - return model; - } - - } +public final class ModerationOptionsBuilder { private final ModerationModelOptionsImpl options = new ModerationModelOptionsImpl(); @@ -53,12 +38,27 @@ public static ModerationOptionsBuilder builder() { } public ModerationOptionsBuilder withModel(String model) { - options.setModel(model); + this.options.setModel(model); return this; } public ModerationOptions build() { - return options; + return this.options; + } + + private class ModerationModelOptionsImpl implements ModerationOptions { + + private String model; + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java index e783cb84f25..02514d4b9af 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ package org.springframework.ai.moderation; -import org.springframework.ai.model.ModelRequest; import java.util.Objects; +import org.springframework.ai.model.ModelRequest; + /** * Represents a prompt for moderation containing a single message and the options for the * moderation model. This class offers constructors to create a prompt from a single @@ -50,11 +51,11 @@ public ModerationPrompt(String instructions) { @Override public ModerationMessage getInstructions() { - return message; + return this.message; } public ModerationOptions getOptions() { - return moderationModelOptions; + return this.moderationModelOptions; } public void setOptions(ModerationOptions moderationModelOptions) { @@ -63,23 +64,26 @@ public void setOptions(ModerationOptions moderationModelOptions) { @Override public String toString() { - return "ModerationPrompt{" + "message=" + message + ", moderationModelOptions=" + moderationModelOptions + '}'; + return "ModerationPrompt{" + "message=" + this.message + ", moderationModelOptions=" + + this.moderationModelOptions + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ModerationPrompt)) + } + if (!(o instanceof ModerationPrompt)) { return false; + } ModerationPrompt that = (ModerationPrompt) o; - return Objects.equals(message, that.message) - && Objects.equals(moderationModelOptions, that.moderationModelOptions); + return Objects.equals(this.message, that.message) + && Objects.equals(this.moderationModelOptions, that.moderationModelOptions); } @Override public int hashCode() { - return Objects.hash(message, moderationModelOptions); + return Objects.hash(this.message, this.moderationModelOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java index 5da1469f2a4..043104436e1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,11 @@ package org.springframework.ai.moderation; -import org.springframework.ai.model.ModelResponse; - import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelResponse; + /** * Represents a response from a moderation process, encapsulating the moderation metadata * and the generated content. This class provides access to both the single generation @@ -48,38 +48,40 @@ public ModerationResponse(Generation generations, ModerationResponseMetadata mod @Override public Generation getResult() { - return generations; + return this.generations; } @Override public List getResults() { - return List.of(generations); + return List.of(this.generations); } @Override public ModerationResponseMetadata getMetadata() { - return moderationResponseMetadata; + return this.moderationResponseMetadata; } @Override public String toString() { - return "ModerationResponse{" + "moderationResponseMetadata=" + moderationResponseMetadata + ", generations=" - + generations + '}'; + return "ModerationResponse{" + "moderationResponseMetadata=" + this.moderationResponseMetadata + + ", generations=" + this.generations + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ModerationResponse that)) + } + if (!(o instanceof ModerationResponse that)) { return false; - return Objects.equals(moderationResponseMetadata, that.moderationResponseMetadata) - && Objects.equals(generations, that.generations); + } + return Objects.equals(this.moderationResponseMetadata, that.moderationResponseMetadata) + && Objects.equals(this.generations, that.generations); } @Override public int hashCode() { - return Objects.hash(moderationResponseMetadata, generations); + return Objects.hash(this.moderationResponseMetadata, this.generations); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java index 785d598c7e2..c32804dea3a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java index d7ec33e5d99..29ea1083075 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Objects; @@ -10,7 +26,7 @@ * @author Ahmed Yousri * @since 1.0.0 */ -public class ModerationResult { +public final class ModerationResult { private boolean flagged; @@ -24,8 +40,12 @@ private ModerationResult(Builder builder) { this.categoryScores = builder.categoryScores; } + public static Builder builder() { + return new Builder(); + } + public boolean isFlagged() { - return flagged; + return this.flagged; } public void setFlagged(boolean flagged) { @@ -33,7 +53,7 @@ public void setFlagged(boolean flagged) { } public Categories getCategories() { - return categories; + return this.categories; } public void setCategories(Categories categories) { @@ -41,15 +61,35 @@ public void setCategories(Categories categories) { } public CategoryScores getCategoryScores() { - return categoryScores; + return this.categoryScores; } public void setCategoryScores(CategoryScores categoryScores) { this.categoryScores = categoryScores; } - public static Builder builder() { - return new Builder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ModerationResult)) { + return false; + } + ModerationResult that = (ModerationResult) o; + return this.flagged == that.flagged && Objects.equals(this.categories, that.categories) + && Objects.equals(this.categoryScores, that.categoryScores); + } + + @Override + public int hashCode() { + return Objects.hash(this.flagged, this.categories, this.categoryScores); + } + + @Override + public String toString() { + return "ModerationResult{" + "flagged=" + this.flagged + ", categories=" + this.categories + ", categoryScores=" + + this.categoryScores + '}'; } public static class Builder { @@ -81,26 +121,4 @@ public ModerationResult build() { } - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof ModerationResult)) - return false; - ModerationResult that = (ModerationResult) o; - return flagged == that.flagged && Objects.equals(categories, that.categories) - && Objects.equals(categoryScores, that.categoryScores); - } - - @Override - public int hashCode() { - return Objects.hash(flagged, categories, categoryScores); - } - - @Override - public String toString() { - return "ModerationResult{" + "flagged=" + flagged + ", categories=" + categories + ", categoryScores=" - + categoryScores + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java index 68b1c3dfffe..a8707b1e6bf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation; import org.springframework.ai.observation.conventions.AiOperationType; @@ -41,7 +42,7 @@ public static Builder builder() { return new Builder(); } - public static class Builder { + public static final class Builder { private String operationType; @@ -61,7 +62,7 @@ public Builder provider(String provider) { } public AiOperationMetadata build() { - return new AiOperationMetadata(operationType, provider); + return new AiOperationMetadata(this.operationType, this.provider); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java index 29566449380..eea71318b25 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -141,7 +142,7 @@ public enum AiObservationAttributes { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java index c3f86f353a1..a44ce86ff28 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -39,7 +40,7 @@ public enum AiObservationEventNames { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java index e8d828e7d62..b729a8e5b37 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -41,7 +42,7 @@ public enum AiObservationMetricAttributes { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java index 3587553196c..fb8ca023a92 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -39,7 +40,7 @@ public enum AiObservationMetricNames { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java index 85fa4f2a26a..3defa442f86 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 88d6a5aaf8d..63e2403c001 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java index a8c2fec383c..013731f94ff 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -40,7 +41,7 @@ public enum AiTokenType { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java index 11c70f45952..d23861d5191 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java index bd869f0cbef..e6f02cdb919 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -109,7 +110,7 @@ public enum VectorStoreObservationAttributes { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java index 9dc843e13aa..589b7ffcdb6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -34,7 +35,7 @@ public enum VectorStoreObservationEventNames { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java index 2ceaf2f54f5..bd518c982cc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java index ec60701a469..9a1ac00ca69 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java index 53f533019b5..34a4401d0a1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.observation.conventions; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java index 1ef4dfd3240..023d5dc6987 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.observation; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java index fce9a93e841..a675fd98ed4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.observation.tracing; -import io.micrometer.tracing.handler.TracingObservationHandler; -import io.opentelemetry.api.trace.Span; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.observation.ChatModelObservationContext; -import org.springframework.ai.model.Content; -import org.springframework.lang.Nullable; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; +package org.springframework.ai.observation.tracing; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -31,6 +22,13 @@ import java.util.Map; import java.util.StringJoiner; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.opentelemetry.api.trace.Span; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.lang.Nullable; + /** * Utilities to prepare and process traces for observability. * @@ -40,6 +38,9 @@ public final class TracingHelper { private static final Logger logger = LoggerFactory.getLogger(TracingHelper.class); + private TracingHelper() { + } + @Nullable public static Span extractOtelSpan(@Nullable TracingObservationHandler.TracingContext tracingContext) { if (tracingContext == null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java index a56714aef78..9ba62979fd4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.util.Collections; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java index 03669f1b747..31112672c4f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import org.springframework.util.StringUtils; @@ -32,19 +33,19 @@ * * @author Christian Tzolov */ -public class ExtractedTextFormatter { +public final class ExtractedTextFormatter { /** Flag indicating if the text should be left-aligned */ - private boolean leftAlignment; + private final boolean leftAlignment; /** Number of top pages to skip before performing delete operations */ - private int numberOfTopPagesToSkipBeforeDelete; + private final int numberOfTopPagesToSkipBeforeDelete; /** Number of top text lines to delete from a page */ - private int numberOfTopTextLinesToDelete; + private final int numberOfTopTextLinesToDelete; /** Number of bottom text lines to delete from a page */ - private int numberOfBottomTextLinesToDelete; + private final int numberOfBottomTextLinesToDelete; /** * Private constructor to initialize the formatter from the builder. @@ -73,6 +74,82 @@ public static ExtractedTextFormatter defaults() { return new Builder().build(); } + /** + * Replaces multiple, adjacent blank lines into a single blank line. + * @param pageText text to adjust the blank lines for. + * @return Returns the same text but with blank lines trimmed. + */ + public static String trimAdjacentBlankLines(String pageText) { + return pageText.replaceAll("(?m)(^ *\n)", "\n").replaceAll("(?m)^$([\r\n]+?)(^$[\r\n]+?^)+", "$1"); + } + + /** + * @param pageText text to align. + * @return Returns the same text but aligned to the left side. + */ + public static String alignToLeft(String pageText) { + return pageText.replaceAll("(?m)(^ *| +(?= |$))", "").replaceAll("(?m)^$( ?)(^$[\r\n]+?^)+", "$1"); + } + + /** + * Removes the specified number of lines from the bottom part of the text. + * @param pageText Text to remove lines from. + * @param numberOfLines Number of lines to remove. + * @return Returns the text striped from last lines. + */ + public static String deleteBottomTextLines(String pageText, int numberOfLines) { + if (!StringUtils.hasText(pageText)) { + return pageText; + } + + int lineCount = 0; + int truncateIndex = pageText.length(); + int nextTruncateIndex = truncateIndex; + while (lineCount < numberOfLines && nextTruncateIndex >= 0) { + nextTruncateIndex = pageText.lastIndexOf(System.lineSeparator(), truncateIndex - 1); + truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; + lineCount++; + } + return pageText.substring(0, truncateIndex); + } + + /** + * Removes a specified number of lines from the top part of the given text. + * + *

+ * This method takes a text and trims it by removing a certain number of lines from + * the top. If the provided text is null or contains only whitespace, it will be + * returned as is. If the number of lines to remove exceeds the actual number of lines + * in the text, the result will be an empty string. + *

+ * + *

+ * The method identifies lines based on the system's line separator, making it + * compatible with different platforms. + *

+ * @param pageText The text from which the top lines need to be removed. If this is + * null, empty, or consists only of whitespace, it will be returned unchanged. + * @param numberOfLines The number of lines to remove from the top of the text. If + * this exceeds the actual number of lines in the text, an empty string will be + * returned. + * @return The text with the specified number of lines removed from the top. + */ + public static String deleteTopTextLines(String pageText, int numberOfLines) { + if (!StringUtils.hasText(pageText)) { + return pageText; + } + int lineCount = 0; + + int truncateIndex = 0; + int nextTruncateIndex = truncateIndex; + while (lineCount < numberOfLines && nextTruncateIndex >= 0) { + nextTruncateIndex = pageText.indexOf(System.lineSeparator(), truncateIndex + 1); + truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; + lineCount++; + } + return pageText.substring(truncateIndex); + } + /** * Formats the provided text according to the formatter's configuration. * @param pageText Text to be formatted. @@ -126,7 +203,7 @@ public String format(String pageText, int pageNumber) { *
  • Number of top text lines to delete to 0
  • *
  • Number of bottom text lines to delete to 0
  • * - * + * * *

    * After configuring the builder, calling the {@link #build()} method will return a @@ -209,80 +286,4 @@ public ExtractedTextFormatter build() { } - /** - * Replaces multiple, adjacent blank lines into a single blank line. - * @param pageText text to adjust the blank lines for. - * @return Returns the same text but with blank lines trimmed. - */ - public static String trimAdjacentBlankLines(String pageText) { - return pageText.replaceAll("(?m)(^ *\n)", "\n").replaceAll("(?m)^$([\r\n]+?)(^$[\r\n]+?^)+", "$1"); - } - - /** - * @param pageText text to align. - * @return Returns the same text but aligned to the left side. - */ - public static String alignToLeft(String pageText) { - return pageText.replaceAll("(?m)(^ *| +(?= |$))", "").replaceAll("(?m)^$( ?)(^$[\r\n]+?^)+", "$1"); - } - - /** - * Removes the specified number of lines from the bottom part of the text. - * @param pageText Text to remove lines from. - * @param numberOfLines Number of lines to remove. - * @return Returns the text striped from last lines. - */ - public static String deleteBottomTextLines(String pageText, int numberOfLines) { - if (!StringUtils.hasText(pageText)) { - return pageText; - } - - int lineCount = 0; - int truncateIndex = pageText.length(); - int nextTruncateIndex = truncateIndex; - while (lineCount < numberOfLines && nextTruncateIndex >= 0) { - nextTruncateIndex = pageText.lastIndexOf(System.lineSeparator(), truncateIndex - 1); - truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; - lineCount++; - } - return pageText.substring(0, truncateIndex); - } - - /** - * Removes a specified number of lines from the top part of the given text. - * - *

    - * This method takes a text and trims it by removing a certain number of lines from - * the top. If the provided text is null or contains only whitespace, it will be - * returned as is. If the number of lines to remove exceeds the actual number of lines - * in the text, the result will be an empty string. - *

    - * - *

    - * The method identifies lines based on the system's line separator, making it - * compatible with different platforms. - *

    - * @param pageText The text from which the top lines need to be removed. If this is - * null, empty, or consists only of whitespace, it will be returned unchanged. - * @param numberOfLines The number of lines to remove from the top of the text. If - * this exceeds the actual number of lines in the text, an empty string will be - * returned. - * @return The text with the specified number of lines removed from the top. - */ - public static String deleteTopTextLines(String pageText, int numberOfLines) { - if (!StringUtils.hasText(pageText)) { - return pageText; - } - int lineCount = 0; - - int truncateIndex = 0; - int nextTruncateIndex = truncateIndex; - while (lineCount < numberOfLines && nextTruncateIndex >= 0) { - nextTruncateIndex = pageText.indexOf(System.lineSeparator(), truncateIndex + 1); - truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; - lineCount++; - } - return pageText.substring(truncateIndex, pageText.length()); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java index 4a4ffb1e9fc..a556e8b65dc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.util.Map; @@ -23,6 +24,8 @@ public interface JsonMetadataGenerator { /** * The input is the JSON document represented as a map, the output are the fields * extracted from the input map that will be used as metadata. + * @param jsonMap json document map + * @return json metadata map */ Map generate(Map jsonMap); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java index 7b4a2e8cc71..b827a4e18c2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.io.IOException; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Collections; import java.util.stream.StreamSupport; import com.fasterxml.jackson.core.type.TypeReference; @@ -37,7 +37,7 @@ * * @author Mark Pollack * @author Christian Tzolov - * @author rivkode + * @author rivkode rivkode * @since 1.0.0 */ public class JsonReader implements DocumentReader { @@ -73,15 +73,15 @@ public JsonReader(Resource resource, JsonMetadataGenerator jsonMetadataGenerator @Override public List get() { try { - JsonNode rootNode = objectMapper.readTree(this.resource.getInputStream()); + JsonNode rootNode = this.objectMapper.readTree(this.resource.getInputStream()); if (rootNode.isArray()) { return StreamSupport.stream(rootNode.spliterator(), true) - .map(jsonNode -> parseJsonNode(jsonNode, objectMapper)) + .map(jsonNode -> parseJsonNode(jsonNode, this.objectMapper)) .toList(); } else { - return Collections.singletonList(parseJsonNode(rootNode, objectMapper)); + return Collections.singletonList(parseJsonNode(rootNode, this.objectMapper)); } } catch (IOException e) { @@ -91,10 +91,11 @@ public List get() { private Document parseJsonNode(JsonNode jsonNode, ObjectMapper objectMapper) { Map item = objectMapper.convertValue(jsonNode, new TypeReference>() { + }); var sb = new StringBuilder(); - jsonKeysToUse.stream().filter(item::containsKey).forEach(key -> { + this.jsonKeysToUse.stream().filter(item::containsKey).forEach(key -> { sb.append(key).append(": ").append(item.get(key)).append(System.lineSeparator()); }); @@ -106,11 +107,11 @@ private Document parseJsonNode(JsonNode jsonNode, ObjectMapper objectMapper) { protected List get(JsonNode rootNode) { if (rootNode.isArray()) { return StreamSupport.stream(rootNode.spliterator(), true) - .map(jsonNode -> parseJsonNode(jsonNode, objectMapper)) + .map(jsonNode -> parseJsonNode(jsonNode, this.objectMapper)) .toList(); } else { - return Collections.singletonList(parseJsonNode(rootNode, objectMapper)); + return Collections.singletonList(parseJsonNode(rootNode, this.objectMapper)); } } @@ -122,7 +123,7 @@ protected List get(JsonNode rootNode) { */ public List get(String pointer) { try { - JsonNode rootNode = objectMapper.readTree(this.resource.getInputStream()); + JsonNode rootNode = this.objectMapper.readTree(this.resource.getInputStream()); JsonNode targetNode = rootNode.at(pointer); if (targetNode.isMissingNode()) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java index 55bb2e2bf70..db9e284d048 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.io.IOException; @@ -46,13 +47,13 @@ public class TextReader implements DocumentReader { */ private final Resource resource; + private final Map customMetadata = new HashMap<>(); + /** * Character set to be used when loading data from the */ private Charset charset = StandardCharsets.UTF_8; - private final Map customMetadata = new HashMap<>(); - public TextReader(String resourceUrl) { this(new DefaultResourceLoader().getResource(resourceUrl)); } @@ -62,15 +63,15 @@ public TextReader(Resource resource) { this.resource = resource; } + public Charset getCharset() { + return this.charset; + } + public void setCharset(Charset charset) { Objects.requireNonNull(charset, "The charset must not be null"); this.charset = charset; } - public Charset getCharset() { - return this.charset; - } - /** * Metadata associated with all documents created by the loader. * @return Metadata to be assigned to the output Documents. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java index 8a1dc60aa01..760a9a4b345 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -87,4 +87,4 @@ public int estimate(Iterable contents) { return totalSize; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java index 03a9eff5aa3..e33c464e9dd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,4 +48,4 @@ public interface TokenCountEstimator { */ int estimate(Iterable messages); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java index 880abc73526..32e201402ac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer; import java.util.ArrayList; @@ -67,7 +68,7 @@ public ContentFormatTransformer(ContentFormatter contentFormatter, boolean disab * @return processed documents */ public List apply(List documents) { - if (contentFormatter != null) { + if (this.contentFormatter != null) { documents.forEach(this::processDocument); } @@ -76,7 +77,7 @@ public List apply(List documents) { private void processDocument(Document document) { if (document.getContentFormatter() instanceof DefaultContentFormatter docFormatter - && contentFormatter instanceof DefaultContentFormatter toUpdateFormatter) { + && this.contentFormatter instanceof DefaultContentFormatter toUpdateFormatter) { updateFormatter(document, docFormatter, toUpdateFormatter); } @@ -99,7 +100,7 @@ private void updateFormatter(Document document, DefaultContentFormatter docForma .withMetadataTemplate(docFormatter.getMetadataTemplate()) .withMetadataSeparator(docFormatter.getMetadataSeparator()); - if (!disableTemplateRewrite) { + if (!this.disableTemplateRewrite) { builder.withTextTemplate(docFormatter.getTextTemplate()); } @@ -107,7 +108,7 @@ private void updateFormatter(Document document, DefaultContentFormatter docForma } private void overrideFormatter(Document document) { - document.setContentFormatter(contentFormatter); + document.setContentFormatter(this.contentFormatter); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java index 10db9364d16..dd02b336bb4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer; import java.util.List; import java.util.Map; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.document.Document; -import org.springframework.ai.document.DocumentTransformer; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentTransformer; import org.springframework.util.Assert; /** @@ -32,14 +33,14 @@ */ public class KeywordMetadataEnricher implements DocumentTransformer { - private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords"; - public static final String CONTEXT_STR_PLACEHOLDER = "context_str"; public static final String KEYWORDS_TEMPLATE = """ {context_str}. Give %s unique keywords for this document. Format as comma separated. Keywords: """; + private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords"; + /** * Model predictor */ @@ -62,7 +63,7 @@ public KeywordMetadataEnricher(ChatModel chatModel, int keywordCount) { public List apply(List documents) { for (Document document : documents) { - var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); + var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, this.keywordCount)); Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getContent())); String keywords = this.chatModel.call(prompt).getResult().getOutput().getContent(); document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords)); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java index 60ecf450b3c..a1537828285 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer; import java.util.ArrayList; @@ -21,11 +22,11 @@ import java.util.Map; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -38,14 +39,6 @@ */ public class SummaryMetadataEnricher implements DocumentTransformer { - private static final String SECTION_SUMMARY_METADATA_KEY = "section_summary"; - - private static final String NEXT_SECTION_SUMMARY_METADATA_KEY = "next_section_summary"; - - private static final String PREV_SECTION_SUMMARY_METADATA_KEY = "prev_section_summary"; - - private static final String CONTEXT_STR_PLACEHOLDER = "context_str"; - public static final String DEFAULT_SUMMARY_EXTRACT_TEMPLATE = """ Here is the content of the section: {context_str} @@ -54,11 +47,13 @@ public class SummaryMetadataEnricher implements DocumentTransformer { Summary:"""; - public enum SummaryType { + private static final String SECTION_SUMMARY_METADATA_KEY = "section_summary"; - PREVIOUS, CURRENT, NEXT + private static final String NEXT_SECTION_SUMMARY_METADATA_KEY = "next_section_summary"; - } + private static final String PREV_SECTION_SUMMARY_METADATA_KEY = "prev_section_summary"; + + private static final String CONTEXT_STR_PLACEHOLDER = "context_str"; /** * AI client. @@ -127,4 +122,10 @@ private Map getSummaryMetadata(int i, List documentSumma return summaryMetadata; } + public enum SummaryType { + + PREVIOUS, CURRENT, NEXT + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java index 809fc556b8f..7d5439c8653 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer.splitter; import java.util.ArrayList; @@ -22,6 +23,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.ContentFormatter; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; @@ -49,14 +51,14 @@ public List split(Document document) { return this.apply(List.of(document)); } - public void setCopyContentFormatter(boolean copyContentFormatter) { - this.copyContentFormatter = copyContentFormatter; - } - public boolean isCopyContentFormatter() { return this.copyContentFormatter; } + public void setCopyContentFormatter(boolean copyContentFormatter) { + this.copyContentFormatter = copyContentFormatter; + } + private List doSplitDocuments(List documents) { List texts = new ArrayList<>(); List> metadataList = new ArrayList<>(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index 420e7a28772..4c129543624 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer.splitter; import java.util.ArrayList; @@ -33,10 +34,6 @@ */ public class TokenTextSplitter extends TextSplitter { - private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); - - private final Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE); - private final static int DEFAULT_CHUNK_SIZE = 800; private final static int MIN_CHUNK_SIZE_CHARS = 350; @@ -47,6 +44,10 @@ public class TokenTextSplitter extends TextSplitter { private final static boolean KEEP_SEPARATOR = true; + private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); + + private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE); + // The target size of each text chunk in tokens private final int chunkSize; @@ -78,6 +79,10 @@ public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengt this.keepSeparator = keepSeparator; } + public static Builder builder() { + return new Builder(); + } + @Override protected List splitText(String text) { return doSplit(text, this.chunkSize); @@ -145,11 +150,7 @@ private String decodeTokens(List tokens) { return this.encoding.decode(tokensIntArray); } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { + public static final class Builder { private int chunkSize; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java index 80176631b30..3686dd417ca 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.util; import java.util.ArrayList; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java index 52801625642..591aa0801ae 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2014-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.util; import java.util.ArrayList; @@ -35,8 +36,8 @@ public abstract class ParsingUtils { private static final String LOWER = "\\p{Ll}"; - private static final String CAMEL_CASE_REGEX = "(? split(String source, boolean toLower) { return Collections.unmodifiableList(result); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java index b430793cdd0..a009e56f78b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Objects; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; import org.springframework.util.Assert; -import java.util.Objects; - /** * Similarity search request builder. Use the {@link #query(String)}, {@link #defaults()} * or {@link #from(SearchRequest)} factory methods to create a new {@link SearchRequest} @@ -30,7 +31,7 @@ * * @author Christian Tzolov */ -public class SearchRequest { +public final class SearchRequest { /** * Similarity threshold that accepts all search scores. A threshold value of 0.0 means @@ -231,19 +232,19 @@ public SearchRequest withFilterExpression(String textExpression) { } public String getQuery() { - return query; + return this.query; } public int getTopK() { - return topK; + return this.topK; } public double getSimilarityThreshold() { - return similarityThreshold; + return this.similarityThreshold; } public Filter.Expression getFilterExpression() { - return filterExpression; + return this.filterExpression; } public boolean hasFilterExpression() { @@ -252,24 +253,27 @@ public boolean hasFilterExpression() { @Override public String toString() { - return "SearchRequest{" + "query='" + query + '\'' + ", topK=" + topK + ", similarityThreshold=" - + similarityThreshold + ", filterExpression=" + filterExpression + '}'; + return "SearchRequest{" + "query='" + this.query + '\'' + ", topK=" + this.topK + ", similarityThreshold=" + + this.similarityThreshold + ", filterExpression=" + this.filterExpression + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } SearchRequest that = (SearchRequest) o; - return topK == that.topK && Double.compare(that.similarityThreshold, similarityThreshold) == 0 - && Objects.equals(query, that.query) && Objects.equals(filterExpression, that.filterExpression); + return this.topK == that.topK && Double.compare(that.similarityThreshold, this.similarityThreshold) == 0 + && Objects.equals(this.query, that.query) + && Objects.equals(this.filterExpression, that.filterExpression); } @Override public int hashCode() { - return Objects.hash(query, topK, similarityThreshold, filterExpression); + return Objects.hash(this.query, this.topK, this.similarityThreshold, this.filterExpression); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index 724c0395d55..f2e558eb6a8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.File; @@ -32,9 +33,15 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.json.JsonMapper; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -45,13 +52,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.core.io.Resource; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.ObjectWriter; - -import io.micrometer.observation.ObservationRegistry; - /** * SimpleVectorStore is a simple implementation of the VectorStore interface. * @@ -177,6 +177,7 @@ public void save(File file) { */ public void load(File file) { TypeReference> typeRef = new TypeReference<>() { + }; try { Map deserializedMap = this.objectMapper.readValue(file, typeRef); @@ -193,6 +194,7 @@ public void load(File file) { */ public void load(Resource resource) { TypeReference> typeRef = new TypeReference<>() { + }; try { Map deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef); @@ -219,6 +221,15 @@ private float[] getUserQueryEmbedding(String query) { return this.embeddingModel.embed(query); } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.SIMPLE.value(), operationName) + .withDimensions(this.embeddingModel.dimensions()) + .withCollectionName("in-memory-map") + .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); + } + public static class Similarity { private String key; @@ -232,7 +243,7 @@ public Similarity(String key, double score) { } - public class EmbeddingMath { + public final class EmbeddingMath { private EmbeddingMath() { throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); @@ -276,13 +287,4 @@ public static float norm(float[] vector) { } - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.SIMPLE.value(), operationName) - .withDimensions(this.embeddingModel.dimensions()) - .withCollectionName("in-memory-map") - .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java index dadcdda3b18..1e8cf62dcde 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java index e3036c075de..53e16c691d4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; /** @@ -64,6 +65,23 @@ */ public class Filter { + /** + * Filter expression operations.
    + * + * - EQ, NE, GT, GTE, LT, LTE operations supports "Key ExprType Value" + * expressions.
    + * + * - AND, OR are binary operations that support "(Expression|Group) ExprType + * (Expression|Group)" expressions.
    + * + * - IN, NIN support "Key (IN|NIN) ArrayValue" expression.
    + */ + public enum ExpressionType { + + AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN, NOT + + } + /** * Mark interface representing the supported expression types: {@link Key}, * {@link Value}, {@link Expression} and {@link Group}. @@ -79,6 +97,7 @@ public interface Operand { * @param key expression key */ public record Key(String key) implements Operand { + } /** @@ -88,22 +107,6 @@ public record Key(String key) implements Operand { * @param value value constant or constant array */ public record Value(Object value) implements Operand { - } - - /** - * Filter expression operations.
    - * - * - EQ, NE, GT, GTE, LT, LTE operations supports "Key ExprType Value" - * expressions.
    - * - * - AND, OR are binary operations that support "(Expression|Group) ExprType - * (Expression|Group)" expressions.
    - * - * - IN, NIN support "Key (IN|NIN) ArrayValue" expression.
    - */ - public enum ExpressionType { - - AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN, NOT } @@ -120,9 +123,11 @@ public enum ExpressionType { * be another {@link Expression}. */ public record Expression(ExpressionType type, Operand left, Operand right) implements Operand { + public Expression(ExpressionType type, Operand operand) { this(type, operand, null); } + } /** @@ -132,6 +137,7 @@ public Expression(ExpressionType type, Operand operand) { * @param content Inner expression to be evaluated as a part of the group. */ public record Group(Expression content) implements Operand { + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java index d3913c1d424..f7410c898c7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -55,20 +56,6 @@ */ public class FilterExpressionBuilder { - public record Op(Filter.Operand expression) { - - public Filter.Expression build() { - if (expression instanceof Filter.Group group) { - // Remove the top-level grouping. - return group.content(); - } - else if (expression instanceof Filter.Expression exp) { - return exp; - } - throw new RuntimeException("Invalid expression: " + expression); - } - } - public Op eq(String key, Object value) { return new Op(new Filter.Expression(ExpressionType.EQ, new Key(key), new Value(value))); } @@ -125,4 +112,19 @@ public Op not(Op content) { return new Op(new Filter.Expression(ExpressionType.NOT, content.expression, null)); } + public record Op(Filter.Operand expression) { + + public Filter.Expression build() { + if (this.expression instanceof Filter.Group group) { + // Remove the top-level grouping. + return group.content(); + } + else if (this.expression instanceof Filter.Expression exp) { + return exp; + } + throw new RuntimeException("Invalid expression: " + this.expression); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java index 127f2dc9270..463da47fdd9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; /** @@ -24,6 +25,6 @@ */ public interface FilterExpressionConverter { - public String convertExpression(Filter.Expression expression); + String convertExpression(Filter.Expression expression); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java index c4fd1a9d62e..7429420bc36 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.ArrayList; @@ -162,7 +163,7 @@ public void clearCache() { /** For testing only */ Map getCache() { - return cache; + return this.cache; } public static class FilterExpressionParseException extends RuntimeException { @@ -301,4 +302,4 @@ public void syntaxError(Recognizer recognizer, Object offendingSymbol, int } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java index 555d2ca8756..ce2bebf9118 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.ArrayList; @@ -29,10 +30,7 @@ * * @author Christian Tzolov */ -public class FilterHelper { - - private FilterHelper() { - } +public final class FilterHelper { private final static Map TYPE_NEGATION_MAP = Map.of(ExpressionType.AND, ExpressionType.OR, ExpressionType.OR, ExpressionType.AND, ExpressionType.EQ, ExpressionType.NE, @@ -40,6 +38,9 @@ private FilterHelper() { ExpressionType.LT, ExpressionType.LT, ExpressionType.GTE, ExpressionType.LTE, ExpressionType.GT, ExpressionType.IN, ExpressionType.NIN, ExpressionType.NIN, ExpressionType.IN); + private FilterHelper() { + } + /** * Transforms the input expression into a semantically equivalent one with negation * operators propagated thought the expression tree by following the negation rules: diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java index 65bca7a3a1c..136aedbab1b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + package org.springframework.ai.vectorstore.filter.antlr4; -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # @@ -422,4 +408,4 @@ public void visitTerminal(TerminalNode node) { public void visitErrorNode(ErrorNode node) { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java index 99b0de1b689..f8a5a204199 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + package org.springframework.ai.vectorstore.filter.antlr4; -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # @@ -244,4 +230,4 @@ public T visitBooleanConstant(FiltersParser.BooleanConstantContext ctx) { return visitChildren(ctx); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java index 87fa8abfdd6..cf4e31c59fc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,147 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + package org.springframework.ai.vectorstore.filter.antlr4; -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ -import org.antlr.v4.runtime.Lexer; import org.antlr.v4.runtime.CharStream; -import org.antlr.v4.runtime.*; -import org.antlr.v4.runtime.atn.*; +import org.antlr.v4.runtime.Lexer; +import org.antlr.v4.runtime.RuntimeMetaData; +import org.antlr.v4.runtime.Vocabulary; +import org.antlr.v4.runtime.VocabularyImpl; +import org.antlr.v4.runtime.atn.ATN; +import org.antlr.v4.runtime.atn.ATNDeserializer; +import org.antlr.v4.runtime.atn.LexerATNSimulator; +import org.antlr.v4.runtime.atn.PredictionContextCache; import org.antlr.v4.runtime.dfa.DFA; @SuppressWarnings({ "all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape" }) public class FiltersLexer extends Lexer { - static { - RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); - } - - protected static final DFA[] _decisionToDFA; - - protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); - public static final int WHERE = 1, DOT = 2, COMMA = 3, LEFT_SQUARE_BRACKETS = 4, RIGHT_SQUARE_BRACKETS = 5, LEFT_PARENTHESIS = 6, RIGHT_PARENTHESIS = 7, EQUALS = 8, MINUS = 9, PLUS = 10, GT = 11, GE = 12, LT = 13, LE = 14, NE = 15, AND = 16, OR = 17, IN = 18, NIN = 19, NOT = 20, BOOLEAN_VALUE = 21, QUOTED_STRING = 22, INTEGER_VALUE = 23, DECIMAL_VALUE = 24, IDENTIFIER = 25, WS = 26; - public static String[] channelNames = { "DEFAULT_TOKEN_CHANNEL", "HIDDEN" }; - - public static String[] modeNames = { "DEFAULT_MODE" }; - - private static String[] makeRuleNames() { - return new String[] { "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", - "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", - "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", - "IDENTIFIER", "DECIMAL_DIGITS", "DIGIT", "LETTER", "WS" }; - } - public static final String[] ruleNames = makeRuleNames(); - private static String[] makeLiteralNames() { - return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", - "'<'", "'<='", "'!='" }; - } - - private static final String[] _LITERAL_NAMES = makeLiteralNames(); - - private static String[] makeSymbolicNames() { - return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", - "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", - "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", - "IDENTIFIER", "WS" }; - } - - private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); - - public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); - /** * @deprecated Use {@link #VOCABULARY} instead. */ @Deprecated public static final String[] tokenNames; - static { - tokenNames = new String[_SYMBOLIC_NAMES.length]; - for (int i = 0; i < tokenNames.length; i++) { - tokenNames[i] = VOCABULARY.getLiteralName(i); - if (tokenNames[i] == null) { - tokenNames[i] = VOCABULARY.getSymbolicName(i); - } - - if (tokenNames[i] == null) { - tokenNames[i] = ""; - } - } - } - - @Override - @Deprecated - public String[] getTokenNames() { - return tokenNames; - } - - @Override - - public Vocabulary getVocabulary() { - return VOCABULARY; - } - - public FiltersLexer(CharStream input) { - super(input); - _interp = new LexerATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); - } - - @Override - public String getGrammarFileName() { - return "Filters.g4"; - } - - @Override - public String[] getRuleNames() { - return ruleNames; - } - - @Override - public String getSerializedATN() { - return _serializedATN; - } - - @Override - public String[] getChannelNames() { - return channelNames; - } - - @Override - public String[] getModeNames() { - return modeNames; - } - - @Override - public ATN getATN() { - return _ATN; - } public static final String _serializedATN = "\u0004\u0000\u001a\u00e5\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002" + "\u0001\u0007\u0001\u0002\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002" @@ -307,6 +201,105 @@ public ATN getATN() { + "\u00cf\u00d6\u00d8\u00e1\u0001\u0000\u0001\u0000"; public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); + + protected static final DFA[] _decisionToDFA; + + protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); + + private static final String[] _LITERAL_NAMES = makeLiteralNames(); + + private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); + + public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); + + public static String[] channelNames = { "DEFAULT_TOKEN_CHANNEL", "HIDDEN" }; + + public static String[] modeNames = { "DEFAULT_MODE" }; + + public FiltersLexer(CharStream input) { + super(input); + _interp = new LexerATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); + } + + private static String[] makeRuleNames() { + return new String[] { "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", + "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", + "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", + "IDENTIFIER", "DECIMAL_DIGITS", "DIGIT", "LETTER", "WS" }; + } + + private static String[] makeLiteralNames() { + return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", + "'<'", "'<='", "'!='" }; + } + + private static String[] makeSymbolicNames() { + return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", + "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", + "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", + "IDENTIFIER", "WS" }; + } + + @Override + @Deprecated + public String[] getTokenNames() { + return tokenNames; + } + + @Override + + public Vocabulary getVocabulary() { + return VOCABULARY; + } + + @Override + public String getGrammarFileName() { + return "Filters.g4"; + } + + @Override + public String[] getRuleNames() { + return ruleNames; + } + + @Override + public String getSerializedATN() { + return _serializedATN; + } + + @Override + public String[] getChannelNames() { + return channelNames; + } + + @Override + public String[] getModeNames() { + return modeNames; + } + + @Override + public ATN getATN() { + return _ATN; + } + + static { + RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); + } + + static { + tokenNames = new String[_SYMBOLIC_NAMES.length]; + for (int i = 0; i < tokenNames.length; i++) { + tokenNames[i] = VOCABULARY.getLiteralName(i); + if (tokenNames[i] == null) { + tokenNames[i] = VOCABULARY.getSymbolicName(i); + } + + if (tokenNames[i] == null) { + tokenNames[i] = ""; + } + } + } + static { _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()]; for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) { @@ -314,4 +307,4 @@ public ATN getATN() { } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java index c16c841b7dc..8e49aeff6b3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + package org.springframework.ai.vectorstore.filter.antlr4; -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # @@ -246,4 +232,4 @@ public interface FiltersListener extends ParseTreeListener { */ void exitBooleanConstant(FiltersParser.BooleanConstantContext ctx); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java index ab66cff95a0..945a3a95334 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,46 +13,40 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + package org.springframework.ai.vectorstore.filter.antlr4; -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ -import org.antlr.v4.runtime.atn.*; -import org.antlr.v4.runtime.dfa.DFA; -import org.antlr.v4.runtime.*; -import org.antlr.v4.runtime.tree.*; import java.util.List; +import org.antlr.v4.runtime.FailedPredicateException; +import org.antlr.v4.runtime.NoViableAltException; +import org.antlr.v4.runtime.Parser; +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.RecognitionException; +import org.antlr.v4.runtime.RuleContext; +import org.antlr.v4.runtime.RuntimeMetaData; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.TokenStream; +import org.antlr.v4.runtime.Vocabulary; +import org.antlr.v4.runtime.VocabularyImpl; +import org.antlr.v4.runtime.atn.ATN; +import org.antlr.v4.runtime.atn.ATNDeserializer; +import org.antlr.v4.runtime.atn.ParserATNSimulator; +import org.antlr.v4.runtime.atn.PredictionContextCache; +import org.antlr.v4.runtime.dfa.DFA; +import org.antlr.v4.runtime.tree.ParseTreeListener; +import org.antlr.v4.runtime.tree.ParseTreeVisitor; +import org.antlr.v4.runtime.tree.TerminalNode; + @SuppressWarnings({ "all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue" }) public class FiltersParser extends Parser { - static { - RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); - } - - protected static final DFA[] _decisionToDFA; - - protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); - public static final int WHERE = 1, DOT = 2, COMMA = 3, LEFT_SQUARE_BRACKETS = 4, RIGHT_SQUARE_BRACKETS = 5, LEFT_PARENTHESIS = 6, RIGHT_PARENTHESIS = 7, EQUALS = 8, MINUS = 9, PLUS = 10, GT = 11, GE = 12, LT = 13, LE = 14, NE = 15, AND = 16, OR = 17, IN = 18, NIN = 19, NOT = 20, BOOLEAN_VALUE = 21, QUOTED_STRING = 22, @@ -61,35 +55,86 @@ public class FiltersParser extends Parser { public static final int RULE_where = 0, RULE_booleanExpression = 1, RULE_constantArray = 2, RULE_compare = 3, RULE_identifier = 4, RULE_constant = 5; - private static String[] makeRuleNames() { - return new String[] { "where", "booleanExpression", "constantArray", "compare", "identifier", "constant" }; - } - public static final String[] ruleNames = makeRuleNames(); - private static String[] makeLiteralNames() { - return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", - "'<'", "'<='", "'!='" }; - } + /** + * @deprecated Use {@link #VOCABULARY} instead. + */ + @Deprecated + public static final String[] tokenNames; - private static final String[] _LITERAL_NAMES = makeLiteralNames(); + public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" + + "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002" + + "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001" + + "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t" + + "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001" + + "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001" + + "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005" + + "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001" + + "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000" + + "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000" + + "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000" + + "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001" + + "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000" + + "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f" + + "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000" + + "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013" + + "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016" + + "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003" + + "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b" + + "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012" + + "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000" + + "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000" + + "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000" + + "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000" + + "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002" + + "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000" + + "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001" + + "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005" + + "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005" + + "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000" + + "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000" + + "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001" + + "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005" + + "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001" + + "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000" + + ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000" + + "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007" + + "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000" + + "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016" + + "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001" + + "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000" + + "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000" + + "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000" + + "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005" + + "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000" + + "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000" + + "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000" + + "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001" + + "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV"; - private static String[] makeSymbolicNames() { - return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", - "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", - "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", - "IDENTIFIER", "WS" }; - } + public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); + + protected static final DFA[] _decisionToDFA; + + protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); + + private static final String[] _LITERAL_NAMES = makeLiteralNames(); private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); - /** - * @deprecated Use {@link #VOCABULARY} instead. - */ - @Deprecated - public static final String[] tokenNames; + static { + RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); + } + static { tokenNames = new String[_SYMBOLIC_NAMES.length]; for (int i = 0; i < tokenNames.length; i++) { @@ -104,6 +149,34 @@ private static String[] makeSymbolicNames() { } } + static { + _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()]; + for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) { + _decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i); + } + } + + public FiltersParser(TokenStream input) { + super(input); + _interp = new ParserATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); + } + + private static String[] makeRuleNames() { + return new String[] { "where", "booleanExpression", "constantArray", "compare", "identifier", "constant" }; + } + + private static String[] makeLiteralNames() { + return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", + "'<'", "'<='", "'!='" }; + } + + private static String[] makeSymbolicNames() { + return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", + "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", + "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", + "IDENTIFIER", "WS" }; + } + @Override @Deprecated public String[] getTokenNames() { @@ -136,57 +209,6 @@ public ATN getATN() { return _ATN; } - public FiltersParser(TokenStream input) { - super(input); - _interp = new ParserATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); - } - - @SuppressWarnings("CheckReturnValue") - public static class WhereContext extends ParserRuleContext { - - public TerminalNode WHERE() { - return getToken(FiltersParser.WHERE, 0); - } - - public BooleanExpressionContext booleanExpression() { - return getRuleContext(BooleanExpressionContext.class, 0); - } - - public TerminalNode EOF() { - return getToken(FiltersParser.EOF, 0); - } - - public WhereContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_where; - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterWhere(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitWhere(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitWhere(this); - else - return visitor.visitChildren(this); - } - - } - public final WhereContext where() throws RecognitionException { WhereContext _localctx = new WhereContext(_ctx, getState()); enterRule(_localctx, 0, RULE_where); @@ -212,488 +234,815 @@ public final WhereContext where() throws RecognitionException { return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class BooleanExpressionContext extends ParserRuleContext { - - public BooleanExpressionContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_booleanExpression; - } - - public BooleanExpressionContext() { - } - - public void copyFrom(BooleanExpressionContext ctx) { - super.copyFrom(ctx); - } - + public final BooleanExpressionContext booleanExpression() throws RecognitionException { + return booleanExpression(0); } - @SuppressWarnings("CheckReturnValue") - public static class NinExpressionContext extends BooleanExpressionContext { - - public IdentifierContext identifier() { - return getRuleContext(IdentifierContext.class, 0); - } - - public ConstantArrayContext constantArray() { - return getRuleContext(ConstantArrayContext.class, 0); - } - - public TerminalNode NOT() { - return getToken(FiltersParser.NOT, 0); - } - - public TerminalNode IN() { - return getToken(FiltersParser.IN, 0); - } - - public TerminalNode NIN() { - return getToken(FiltersParser.NIN, 0); - } + private BooleanExpressionContext booleanExpression(int _p) throws RecognitionException { + ParserRuleContext _parentctx = _ctx; + int _parentState = getState(); + BooleanExpressionContext _localctx = new BooleanExpressionContext(_ctx, _parentState); + BooleanExpressionContext _prevctx = _localctx; + int _startState = 2; + enterRecursionRule(_localctx, 2, RULE_booleanExpression, _p); + try { + int _alt; + enterOuterAlt(_localctx, 1); + { + setState(39); + _errHandler.sync(this); + switch (getInterpreter().adaptivePredict(_input, 1, _ctx)) { + case 1: { + _localctx = new CompareExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; - public NinExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + setState(17); + identifier(); + setState(18); + compare(); + setState(19); + constant(); + } + break; + case 2: { + _localctx = new InExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; + setState(21); + identifier(); + setState(22); + match(IN); + setState(23); + constantArray(); + } + break; + case 3: { + _localctx = new NinExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; + setState(25); + identifier(); + setState(29); + _errHandler.sync(this); + switch (_input.LA(1)) { + case NOT: { + setState(26); + match(NOT); + setState(27); + match(IN); + } + break; + case NIN: { + setState(28); + match(NIN); + } + break; + default: + throw new NoViableAltException(this); + } + setState(31); + constantArray(); + } + break; + case 4: { + _localctx = new GroupExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; + setState(33); + match(LEFT_PARENTHESIS); + setState(34); + booleanExpression(0); + setState(35); + match(RIGHT_PARENTHESIS); + } + break; + case 5: { + _localctx = new NotExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; + setState(37); + match(NOT); + setState(38); + booleanExpression(1); + } + break; + } + _ctx.stop = _input.LT(-1); + setState(49); + _errHandler.sync(this); + _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); + while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER) { + if (_alt == 1) { + if (_parseListeners != null) { + triggerExitRuleEvent(); + } + _prevctx = _localctx; + { + setState(47); + _errHandler.sync(this); + switch (getInterpreter().adaptivePredict(_input, 2, _ctx)) { + case 1: { + _localctx = new AndExpressionContext( + new BooleanExpressionContext(_parentctx, _parentState)); + ((AndExpressionContext) _localctx).left = _prevctx; + pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); + setState(41); + if (!(precpred(_ctx, 4))) { + throw new FailedPredicateException(this, "precpred(_ctx, 4)"); + } + setState(42); + ((AndExpressionContext) _localctx).operator = match(AND); + setState(43); + ((AndExpressionContext) _localctx).right = booleanExpression(5); + } + break; + case 2: { + _localctx = new OrExpressionContext( + new BooleanExpressionContext(_parentctx, _parentState)); + ((OrExpressionContext) _localctx).left = _prevctx; + pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); + setState(44); + if (!(precpred(_ctx, 3))) { + throw new FailedPredicateException(this, "precpred(_ctx, 3)"); + } + setState(45); + ((OrExpressionContext) _localctx).operator = match(OR); + setState(46); + ((OrExpressionContext) _localctx).right = booleanExpression(4); + } + break; + } + } + } + setState(51); + _errHandler.sync(this); + _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); + } + } } + catch (RecognitionException re) { + _localctx.exception = re; + _errHandler.reportError(this, re); + _errHandler.recover(this, re); + } + finally { + unrollRecursionContexts(_parentctx); + } + return _localctx; + } - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterNinExpression(this); + public final ConstantArrayContext constantArray() throws RecognitionException { + ConstantArrayContext _localctx = new ConstantArrayContext(_ctx, getState()); + enterRule(_localctx, 4, RULE_constantArray); + int _la; + try { + enterOuterAlt(_localctx, 1); + { + setState(52); + match(LEFT_SQUARE_BRACKETS); + setState(53); + constant(); + setState(58); + _errHandler.sync(this); + _la = _input.LA(1); + while (_la == COMMA) { + { + { + setState(54); + match(COMMA); + setState(55); + constant(); + } + } + setState(60); + _errHandler.sync(this); + _la = _input.LA(1); + } + setState(61); + match(RIGHT_SQUARE_BRACKETS); + } + } + catch (RecognitionException re) { + _localctx.exception = re; + _errHandler.reportError(this, re); + _errHandler.recover(this, re); + } + finally { + exitRule(); } + return _localctx; + } - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitNinExpression(this); + public final CompareContext compare() throws RecognitionException { + CompareContext _localctx = new CompareContext(_ctx, getState()); + enterRule(_localctx, 6, RULE_compare); + int _la; + try { + enterOuterAlt(_localctx, 1); + { + setState(63); + _la = _input.LA(1); + if (!((((_la) & ~0x3f) == 0 && ((1L << _la) & 63744L) != 0))) { + _errHandler.recoverInline(this); + } + else { + if (_input.LA(1) == Token.EOF) { + matchedEOF = true; + } + _errHandler.reportMatch(this); + consume(); + } + } + } + catch (RecognitionException re) { + _localctx.exception = re; + _errHandler.reportError(this, re); + _errHandler.recover(this, re); + } + finally { + exitRule(); } + return _localctx; + } - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitNinExpression(this); - else - return visitor.visitChildren(this); + public final IdentifierContext identifier() throws RecognitionException { + IdentifierContext _localctx = new IdentifierContext(_ctx, getState()); + enterRule(_localctx, 8, RULE_identifier); + try { + setState(70); + _errHandler.sync(this); + switch (getInterpreter().adaptivePredict(_input, 5, _ctx)) { + case 1: + enterOuterAlt(_localctx, 1); { + setState(65); + match(IDENTIFIER); + setState(66); + match(DOT); + setState(67); + match(IDENTIFIER); + } + break; + case 2: + enterOuterAlt(_localctx, 2); { + setState(68); + match(IDENTIFIER); + } + break; + case 3: + enterOuterAlt(_localctx, 3); { + setState(69); + match(QUOTED_STRING); + } + break; + } + } + catch (RecognitionException re) { + _localctx.exception = re; + _errHandler.reportError(this, re); + _errHandler.recover(this, re); + } + finally { + exitRule(); } + return _localctx; + } + + public final ConstantContext constant() throws RecognitionException { + ConstantContext _localctx = new ConstantContext(_ctx, getState()); + enterRule(_localctx, 10, RULE_constant); + int _la; + try { + int _alt; + setState(86); + _errHandler.sync(this); + switch (getInterpreter().adaptivePredict(_input, 9, _ctx)) { + case 1: + _localctx = new IntegerConstantContext(_localctx); + enterOuterAlt(_localctx, 1); { + setState(73); + _errHandler.sync(this); + _la = _input.LA(1); + if (_la == MINUS || _la == PLUS) { + { + setState(72); + _la = _input.LA(1); + if (!(_la == MINUS || _la == PLUS)) { + _errHandler.recoverInline(this); + } + else { + if (_input.LA(1) == Token.EOF) { + matchedEOF = true; + } + _errHandler.reportMatch(this); + consume(); + } + } + } + + setState(75); + match(INTEGER_VALUE); + } + break; + case 2: + _localctx = new DecimalConstantContext(_localctx); + enterOuterAlt(_localctx, 2); { + setState(77); + _errHandler.sync(this); + _la = _input.LA(1); + if (_la == MINUS || _la == PLUS) { + { + setState(76); + _la = _input.LA(1); + if (!(_la == MINUS || _la == PLUS)) { + _errHandler.recoverInline(this); + } + else { + if (_input.LA(1) == Token.EOF) { + matchedEOF = true; + } + _errHandler.reportMatch(this); + consume(); + } + } + } + setState(79); + match(DECIMAL_VALUE); + } + break; + case 3: + _localctx = new TextConstantContext(_localctx); + enterOuterAlt(_localctx, 3); { + setState(81); + _errHandler.sync(this); + _alt = 1; + do { + switch (_alt) { + case 1: { + { + setState(80); + match(QUOTED_STRING); + } + } + break; + default: + throw new NoViableAltException(this); + } + setState(83); + _errHandler.sync(this); + _alt = getInterpreter().adaptivePredict(_input, 8, _ctx); + } + while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER); + } + break; + case 4: + _localctx = new BooleanConstantContext(_localctx); + enterOuterAlt(_localctx, 4); { + setState(85); + match(BOOLEAN_VALUE); + } + break; + } + } + catch (RecognitionException re) { + _localctx.exception = re; + _errHandler.reportError(this, re); + _errHandler.recover(this, re); + } + finally { + exitRule(); + } + return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class AndExpressionContext extends BooleanExpressionContext { + public boolean sempred(RuleContext _localctx, int ruleIndex, int predIndex) { + switch (ruleIndex) { + case 1: + return booleanExpression_sempred((BooleanExpressionContext) _localctx, predIndex); + } + return true; + } - public BooleanExpressionContext left; + private boolean booleanExpression_sempred(BooleanExpressionContext _localctx, int predIndex) { + switch (predIndex) { + case 0: + return precpred(_ctx, 4); + case 1: + return precpred(_ctx, 3); + } + return true; + } - public Token operator; + @SuppressWarnings("CheckReturnValue") + public static class WhereContext extends ParserRuleContext { - public BooleanExpressionContext right; + public WhereContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } - public List booleanExpression() { - return getRuleContexts(BooleanExpressionContext.class); + public TerminalNode WHERE() { + return getToken(FiltersParser.WHERE, 0); } - public BooleanExpressionContext booleanExpression(int i) { - return getRuleContext(BooleanExpressionContext.class, i); + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); } - public TerminalNode AND() { - return getToken(FiltersParser.AND, 0); + public TerminalNode EOF() { + return getToken(FiltersParser.EOF, 0); } - public AndExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + @Override + public int getRuleIndex() { + return RULE_where; } @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterAndExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterWhere(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitAndExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitWhere(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitAndExpression(this); - else + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitWhere(this); + } + else { return visitor.visitChildren(this); + } } } @SuppressWarnings("CheckReturnValue") - public static class InExpressionContext extends BooleanExpressionContext { + public static class BooleanExpressionContext extends ParserRuleContext { - public IdentifierContext identifier() { - return getRuleContext(IdentifierContext.class, 0); + public BooleanExpressionContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); } - public TerminalNode IN() { - return getToken(FiltersParser.IN, 0); + public BooleanExpressionContext() { } - public ConstantArrayContext constantArray() { - return getRuleContext(ConstantArrayContext.class, 0); + @Override + public int getRuleIndex() { + return RULE_booleanExpression; } - public InExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + public void copyFrom(BooleanExpressionContext ctx) { + super.copyFrom(ctx); } - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterInExpression(this); - } + } - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitInExpression(this); - } + @SuppressWarnings("CheckReturnValue") + public static class NinExpressionContext extends BooleanExpressionContext { - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitInExpression(this); - else - return visitor.visitChildren(this); + public NinExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); } - } + public IdentifierContext identifier() { + return getRuleContext(IdentifierContext.class, 0); + } - @SuppressWarnings("CheckReturnValue") - public static class NotExpressionContext extends BooleanExpressionContext { + public ConstantArrayContext constantArray() { + return getRuleContext(ConstantArrayContext.class, 0); + } public TerminalNode NOT() { return getToken(FiltersParser.NOT, 0); } - public BooleanExpressionContext booleanExpression() { - return getRuleContext(BooleanExpressionContext.class, 0); + public TerminalNode IN() { + return getToken(FiltersParser.IN, 0); } - public NotExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + public TerminalNode NIN() { + return getToken(FiltersParser.NIN, 0); } @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterNotExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterNinExpression(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitNotExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitNinExpression(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitNotExpression(this); - else + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitNinExpression(this); + } + else { return visitor.visitChildren(this); + } } } @SuppressWarnings("CheckReturnValue") - public static class CompareExpressionContext extends BooleanExpressionContext { + public static class AndExpressionContext extends BooleanExpressionContext { - public IdentifierContext identifier() { - return getRuleContext(IdentifierContext.class, 0); + public BooleanExpressionContext left; + + public Token operator; + + public BooleanExpressionContext right; + + public AndExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); } - public CompareContext compare() { - return getRuleContext(CompareContext.class, 0); + public List booleanExpression() { + return getRuleContexts(BooleanExpressionContext.class); } - public ConstantContext constant() { - return getRuleContext(ConstantContext.class, 0); + public BooleanExpressionContext booleanExpression(int i) { + return getRuleContext(BooleanExpressionContext.class, i); } - public CompareExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + public TerminalNode AND() { + return getToken(FiltersParser.AND, 0); } @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterCompareExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterAndExpression(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitCompareExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitAndExpression(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitCompareExpression(this); - else + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitAndExpression(this); + } + else { return visitor.visitChildren(this); + } } } @SuppressWarnings("CheckReturnValue") - public static class OrExpressionContext extends BooleanExpressionContext { - - public BooleanExpressionContext left; - - public Token operator; - - public BooleanExpressionContext right; + public static class InExpressionContext extends BooleanExpressionContext { - public List booleanExpression() { - return getRuleContexts(BooleanExpressionContext.class); + public InExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); } - public BooleanExpressionContext booleanExpression(int i) { - return getRuleContext(BooleanExpressionContext.class, i); + public IdentifierContext identifier() { + return getRuleContext(IdentifierContext.class, 0); } - public TerminalNode OR() { - return getToken(FiltersParser.OR, 0); + public TerminalNode IN() { + return getToken(FiltersParser.IN, 0); } - public OrExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + public ConstantArrayContext constantArray() { + return getRuleContext(ConstantArrayContext.class, 0); } @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterOrExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterInExpression(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitOrExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitInExpression(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitOrExpression(this); - else + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitInExpression(this); + } + else { return visitor.visitChildren(this); + } } } @SuppressWarnings("CheckReturnValue") - public static class GroupExpressionContext extends BooleanExpressionContext { - - public TerminalNode LEFT_PARENTHESIS() { - return getToken(FiltersParser.LEFT_PARENTHESIS, 0); - } + public static class NotExpressionContext extends BooleanExpressionContext { - public BooleanExpressionContext booleanExpression() { - return getRuleContext(BooleanExpressionContext.class, 0); + public NotExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); } - public TerminalNode RIGHT_PARENTHESIS() { - return getToken(FiltersParser.RIGHT_PARENTHESIS, 0); + public TerminalNode NOT() { + return getToken(FiltersParser.NOT, 0); } - public GroupExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); } @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterGroupExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterNotExpression(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitGroupExpression(this); + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitNotExpression(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitGroupExpression(this); - else + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitNotExpression(this); + } + else { return visitor.visitChildren(this); + } } } - public final BooleanExpressionContext booleanExpression() throws RecognitionException { - return booleanExpression(0); - } + @SuppressWarnings("CheckReturnValue") + public static class CompareExpressionContext extends BooleanExpressionContext { - private BooleanExpressionContext booleanExpression(int _p) throws RecognitionException { - ParserRuleContext _parentctx = _ctx; - int _parentState = getState(); - BooleanExpressionContext _localctx = new BooleanExpressionContext(_ctx, _parentState); - BooleanExpressionContext _prevctx = _localctx; - int _startState = 2; - enterRecursionRule(_localctx, 2, RULE_booleanExpression, _p); - try { - int _alt; - enterOuterAlt(_localctx, 1); - { - setState(39); - _errHandler.sync(this); - switch (getInterpreter().adaptivePredict(_input, 1, _ctx)) { - case 1: { - _localctx = new CompareExpressionContext(_localctx); - _ctx = _localctx; - _prevctx = _localctx; + public CompareExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } - setState(17); - identifier(); - setState(18); - compare(); - setState(19); - constant(); - } - break; - case 2: { - _localctx = new InExpressionContext(_localctx); - _ctx = _localctx; - _prevctx = _localctx; - setState(21); - identifier(); - setState(22); - match(IN); - setState(23); - constantArray(); - } - break; - case 3: { - _localctx = new NinExpressionContext(_localctx); - _ctx = _localctx; - _prevctx = _localctx; - setState(25); - identifier(); - setState(29); - _errHandler.sync(this); - switch (_input.LA(1)) { - case NOT: { - setState(26); - match(NOT); - setState(27); - match(IN); - } - break; - case NIN: { - setState(28); - match(NIN); - } - break; - default: - throw new NoViableAltException(this); - } - setState(31); - constantArray(); - } - break; - case 4: { - _localctx = new GroupExpressionContext(_localctx); - _ctx = _localctx; - _prevctx = _localctx; - setState(33); - match(LEFT_PARENTHESIS); - setState(34); - booleanExpression(0); - setState(35); - match(RIGHT_PARENTHESIS); - } - break; - case 5: { - _localctx = new NotExpressionContext(_localctx); - _ctx = _localctx; - _prevctx = _localctx; - setState(37); - match(NOT); - setState(38); - booleanExpression(1); - } - break; - } - _ctx.stop = _input.LT(-1); - setState(49); - _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); - while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER) { - if (_alt == 1) { - if (_parseListeners != null) - triggerExitRuleEvent(); - _prevctx = _localctx; - { - setState(47); - _errHandler.sync(this); - switch (getInterpreter().adaptivePredict(_input, 2, _ctx)) { - case 1: { - _localctx = new AndExpressionContext( - new BooleanExpressionContext(_parentctx, _parentState)); - ((AndExpressionContext) _localctx).left = _prevctx; - pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); - setState(41); - if (!(precpred(_ctx, 4))) - throw new FailedPredicateException(this, "precpred(_ctx, 4)"); - setState(42); - ((AndExpressionContext) _localctx).operator = match(AND); - setState(43); - ((AndExpressionContext) _localctx).right = booleanExpression(5); - } - break; - case 2: { - _localctx = new OrExpressionContext( - new BooleanExpressionContext(_parentctx, _parentState)); - ((OrExpressionContext) _localctx).left = _prevctx; - pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); - setState(44); - if (!(precpred(_ctx, 3))) - throw new FailedPredicateException(this, "precpred(_ctx, 3)"); - setState(45); - ((OrExpressionContext) _localctx).operator = match(OR); - setState(46); - ((OrExpressionContext) _localctx).right = booleanExpression(4); - } - break; - } - } - } - setState(51); - _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); - } + public IdentifierContext identifier() { + return getRuleContext(IdentifierContext.class, 0); + } + + public CompareContext compare() { + return getRuleContext(CompareContext.class, 0); + } + + public ConstantContext constant() { + return getRuleContext(ConstantContext.class, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterCompareExpression(this); } } - catch (RecognitionException re) { - _localctx.exception = re; - _errHandler.reportError(this, re); - _errHandler.recover(this, re); + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitCompareExpression(this); + } } - finally { - unrollRecursionContexts(_parentctx); + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitCompareExpression(this); + } + else { + return visitor.visitChildren(this); + } } - return _localctx; + + } + + @SuppressWarnings("CheckReturnValue") + public static class OrExpressionContext extends BooleanExpressionContext { + + public BooleanExpressionContext left; + + public Token operator; + + public BooleanExpressionContext right; + + public OrExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public List booleanExpression() { + return getRuleContexts(BooleanExpressionContext.class); + } + + public BooleanExpressionContext booleanExpression(int i) { + return getRuleContext(BooleanExpressionContext.class, i); + } + + public TerminalNode OR() { + return getToken(FiltersParser.OR, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterOrExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitOrExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitOrExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class GroupExpressionContext extends BooleanExpressionContext { + + public GroupExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public TerminalNode LEFT_PARENTHESIS() { + return getToken(FiltersParser.LEFT_PARENTHESIS, 0); + } + + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); + } + + public TerminalNode RIGHT_PARENTHESIS() { + return getToken(FiltersParser.RIGHT_PARENTHESIS, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterGroupExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitGroupExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitGroupExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + } @SuppressWarnings("CheckReturnValue") public static class ConstantArrayContext extends ParserRuleContext { + public ConstantArrayContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + public TerminalNode LEFT_SQUARE_BRACKETS() { return getToken(FiltersParser.LEFT_SQUARE_BRACKETS, 0); } @@ -718,10 +1067,6 @@ public TerminalNode COMMA(int i) { return getToken(FiltersParser.COMMA, i); } - public ConstantArrayContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - @Override public int getRuleIndex() { return RULE_constantArray; @@ -729,71 +1074,37 @@ public int getRuleIndex() { @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterConstantArray(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).exitConstantArray(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) + if (visitor instanceof FiltersVisitor) { return ((FiltersVisitor) visitor).visitConstantArray(this); - else + } + else { return visitor.visitChildren(this); - } - - } - - public final ConstantArrayContext constantArray() throws RecognitionException { - ConstantArrayContext _localctx = new ConstantArrayContext(_ctx, getState()); - enterRule(_localctx, 4, RULE_constantArray); - int _la; - try { - enterOuterAlt(_localctx, 1); - { - setState(52); - match(LEFT_SQUARE_BRACKETS); - setState(53); - constant(); - setState(58); - _errHandler.sync(this); - _la = _input.LA(1); - while (_la == COMMA) { - { - { - setState(54); - match(COMMA); - setState(55); - constant(); - } - } - setState(60); - _errHandler.sync(this); - _la = _input.LA(1); - } - setState(61); - match(RIGHT_SQUARE_BRACKETS); } } - catch (RecognitionException re) { - _localctx.exception = re; - _errHandler.reportError(this, re); - _errHandler.recover(this, re); - } - finally { - exitRule(); - } - return _localctx; + } @SuppressWarnings("CheckReturnValue") public static class CompareContext extends ParserRuleContext { + public CompareContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + public TerminalNode EQUALS() { return getToken(FiltersParser.EQUALS, 0); } @@ -818,10 +1129,6 @@ public TerminalNode NE() { return getToken(FiltersParser.NE, 0); } - public CompareContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - @Override public int getRuleIndex() { return RULE_compare; @@ -829,60 +1136,37 @@ public int getRuleIndex() { @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterCompare(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).exitCompare(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) + if (visitor instanceof FiltersVisitor) { return ((FiltersVisitor) visitor).visitCompare(this); - else + } + else { return visitor.visitChildren(this); - } - - } - - public final CompareContext compare() throws RecognitionException { - CompareContext _localctx = new CompareContext(_ctx, getState()); - enterRule(_localctx, 6, RULE_compare); - int _la; - try { - enterOuterAlt(_localctx, 1); - { - setState(63); - _la = _input.LA(1); - if (!((((_la) & ~0x3f) == 0 && ((1L << _la) & 63744L) != 0))) { - _errHandler.recoverInline(this); - } - else { - if (_input.LA(1) == Token.EOF) - matchedEOF = true; - _errHandler.reportMatch(this); - consume(); - } } } - catch (RecognitionException re) { - _localctx.exception = re; - _errHandler.reportError(this, re); - _errHandler.recover(this, re); - } - finally { - exitRule(); - } - return _localctx; + } @SuppressWarnings("CheckReturnValue") public static class IdentifierContext extends ParserRuleContext { + public IdentifierContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + public List IDENTIFIER() { return getTokens(FiltersParser.IDENTIFIER); } @@ -899,10 +1183,6 @@ public TerminalNode QUOTED_STRING() { return getToken(FiltersParser.QUOTED_STRING, 0); } - public IdentifierContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - @Override public int getRuleIndex() { return RULE_identifier; @@ -910,66 +1190,28 @@ public int getRuleIndex() { @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterIdentifier(this); + } } @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitIdentifier(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitIdentifier(this); - else - return visitor.visitChildren(this); - } - - } - - public final IdentifierContext identifier() throws RecognitionException { - IdentifierContext _localctx = new IdentifierContext(_ctx, getState()); - enterRule(_localctx, 8, RULE_identifier); - try { - setState(70); - _errHandler.sync(this); - switch (getInterpreter().adaptivePredict(_input, 5, _ctx)) { - case 1: - enterOuterAlt(_localctx, 1); { - setState(65); - match(IDENTIFIER); - setState(66); - match(DOT); - setState(67); - match(IDENTIFIER); - } - break; - case 2: - enterOuterAlt(_localctx, 2); { - setState(68); - match(IDENTIFIER); - } - break; - case 3: - enterOuterAlt(_localctx, 3); { - setState(69); - match(QUOTED_STRING); - } - break; + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitIdentifier(this); } } - catch (RecognitionException re) { - _localctx.exception = re; - _errHandler.reportError(this, re); - _errHandler.recover(this, re); - } - finally { - exitRule(); + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitIdentifier(this); + } + else { + return visitor.visitChildren(this); + } } - return _localctx; + } @SuppressWarnings("CheckReturnValue") @@ -979,14 +1221,14 @@ public ConstantContext(ParserRuleContext parent, int invokingState) { super(parent, invokingState); } + public ConstantContext() { + } + @Override public int getRuleIndex() { return RULE_constant; } - public ConstantContext() { - } - public void copyFrom(ConstantContext ctx) { super.copyFrom(ctx); } @@ -996,6 +1238,10 @@ public void copyFrom(ConstantContext ctx) { @SuppressWarnings("CheckReturnValue") public static class DecimalConstantContext extends ConstantContext { + public DecimalConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + public TerminalNode DECIMAL_VALUE() { return getToken(FiltersParser.DECIMAL_VALUE, 0); } @@ -1008,28 +1254,28 @@ public TerminalNode PLUS() { return getToken(FiltersParser.PLUS, 0); } - public DecimalConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterDecimalConstant(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).exitDecimalConstant(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) + if (visitor instanceof FiltersVisitor) { return ((FiltersVisitor) visitor).visitDecimalConstant(this); - else + } + else { return visitor.visitChildren(this); + } } } @@ -1037,6 +1283,10 @@ public T accept(ParseTreeVisitor visitor) { @SuppressWarnings("CheckReturnValue") public static class TextConstantContext extends ConstantContext { + public TextConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + public List QUOTED_STRING() { return getTokens(FiltersParser.QUOTED_STRING); } @@ -1045,28 +1295,28 @@ public TerminalNode QUOTED_STRING(int i) { return getToken(FiltersParser.QUOTED_STRING, i); } - public TextConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterTextConstant(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).exitTextConstant(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) + if (visitor instanceof FiltersVisitor) { return ((FiltersVisitor) visitor).visitTextConstant(this); - else + } + else { return visitor.visitChildren(this); + } } } @@ -1074,32 +1324,36 @@ public T accept(ParseTreeVisitor visitor) { @SuppressWarnings("CheckReturnValue") public static class BooleanConstantContext extends ConstantContext { - public TerminalNode BOOLEAN_VALUE() { - return getToken(FiltersParser.BOOLEAN_VALUE, 0); - } - public BooleanConstantContext(ConstantContext ctx) { copyFrom(ctx); } + public TerminalNode BOOLEAN_VALUE() { + return getToken(FiltersParser.BOOLEAN_VALUE, 0); + } + @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterBooleanConstant(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).exitBooleanConstant(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) + if (visitor instanceof FiltersVisitor) { return ((FiltersVisitor) visitor).visitBooleanConstant(this); - else + } + else { return visitor.visitChildren(this); + } } } @@ -1107,6 +1361,10 @@ public T accept(ParseTreeVisitor visitor) { @SuppressWarnings("CheckReturnValue") public static class IntegerConstantContext extends ConstantContext { + public IntegerConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + public TerminalNode INTEGER_VALUE() { return getToken(FiltersParser.INTEGER_VALUE, 0); } @@ -1119,218 +1377,30 @@ public TerminalNode PLUS() { return getToken(FiltersParser.PLUS, 0); } - public IntegerConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - @Override public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).enterIntegerConstant(this); + } } @Override public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) + if (listener instanceof FiltersListener) { ((FiltersListener) listener).exitIntegerConstant(this); + } } @Override public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) + if (visitor instanceof FiltersVisitor) { return ((FiltersVisitor) visitor).visitIntegerConstant(this); - else + } + else { return visitor.visitChildren(this); - } - - } - - public final ConstantContext constant() throws RecognitionException { - ConstantContext _localctx = new ConstantContext(_ctx, getState()); - enterRule(_localctx, 10, RULE_constant); - int _la; - try { - int _alt; - setState(86); - _errHandler.sync(this); - switch (getInterpreter().adaptivePredict(_input, 9, _ctx)) { - case 1: - _localctx = new IntegerConstantContext(_localctx); - enterOuterAlt(_localctx, 1); { - setState(73); - _errHandler.sync(this); - _la = _input.LA(1); - if (_la == MINUS || _la == PLUS) { - { - setState(72); - _la = _input.LA(1); - if (!(_la == MINUS || _la == PLUS)) { - _errHandler.recoverInline(this); - } - else { - if (_input.LA(1) == Token.EOF) - matchedEOF = true; - _errHandler.reportMatch(this); - consume(); - } - } - } - - setState(75); - match(INTEGER_VALUE); - } - break; - case 2: - _localctx = new DecimalConstantContext(_localctx); - enterOuterAlt(_localctx, 2); { - setState(77); - _errHandler.sync(this); - _la = _input.LA(1); - if (_la == MINUS || _la == PLUS) { - { - setState(76); - _la = _input.LA(1); - if (!(_la == MINUS || _la == PLUS)) { - _errHandler.recoverInline(this); - } - else { - if (_input.LA(1) == Token.EOF) - matchedEOF = true; - _errHandler.reportMatch(this); - consume(); - } - } - } - - setState(79); - match(DECIMAL_VALUE); - } - break; - case 3: - _localctx = new TextConstantContext(_localctx); - enterOuterAlt(_localctx, 3); { - setState(81); - _errHandler.sync(this); - _alt = 1; - do { - switch (_alt) { - case 1: { - { - setState(80); - match(QUOTED_STRING); - } - } - break; - default: - throw new NoViableAltException(this); - } - setState(83); - _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input, 8, _ctx); - } - while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER); - } - break; - case 4: - _localctx = new BooleanConstantContext(_localctx); - enterOuterAlt(_localctx, 4); { - setState(85); - match(BOOLEAN_VALUE); - } - break; } } - catch (RecognitionException re) { - _localctx.exception = re; - _errHandler.reportError(this, re); - _errHandler.recover(this, re); - } - finally { - exitRule(); - } - return _localctx; - } - - public boolean sempred(RuleContext _localctx, int ruleIndex, int predIndex) { - switch (ruleIndex) { - case 1: - return booleanExpression_sempred((BooleanExpressionContext) _localctx, predIndex); - } - return true; - } - - private boolean booleanExpression_sempred(BooleanExpressionContext _localctx, int predIndex) { - switch (predIndex) { - case 0: - return precpred(_ctx, 4); - case 1: - return precpred(_ctx, 3); - } - return true; - } - - public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" - + "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002" - + "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001" - + "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t" - + "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001" - + "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001" - + "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005" - + "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001" - + "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000" - + "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000" - + "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000" - + "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001" - + "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000" - + "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f" - + "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000" - + "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013" - + "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016" - + "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003" - + "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b" - + "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012" - + "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000" - + "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000" - + "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000" - + "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000" - + "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002" - + "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000" - + "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001" - + "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005" - + "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005" - + "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000" - + "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000" - + "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001" - + "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005" - + "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001" - + "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000" - + ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000" - + "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007" - + "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000" - + "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016" - + "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001" - + "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000" - + "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000" - + "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000" - + "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005" - + "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000" - + "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000" - + "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000" - + "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001" - + "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV"; - public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); - static { - _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()]; - for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) { - _decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i); - } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java index 3f099b18229..887159c2b73 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + package org.springframework.ai.vectorstore.filter.antlr4; -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + +import org.antlr.v4.runtime.tree.ParseTreeVisitor; // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ -import org.antlr.v4.runtime.tree.ParseTreeVisitor; - /** * This interface defines a complete generic visitor for a parse tree produced by * {@link FiltersParser}. @@ -163,4 +149,4 @@ public interface FiltersVisitor extends ParseTreeVisitor { */ T visitBooleanConstant(FiltersParser.BooleanConstantContext ctx); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java index b3fbeda677c..808e790e8fb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import java.util.List; import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.filter.FilterHelper; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Operand; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.filter.FilterHelper; /** * AbstractFilterExpressionConverter is an abstract class that implements the diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java index 4f8c6c061cd..64877fc24a5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import org.springframework.ai.vectorstore.filter.Filter.Expression; @@ -60,4 +61,4 @@ protected void doKey(Key key, StringBuilder context) { context.append("\"" + identifier + "\": "); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java index b2e93fcf746..14d2d1216d5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import org.springframework.ai.vectorstore.filter.Filter.Expression; @@ -47,4 +48,4 @@ public void doEndGroup(Group group, StringBuilder context) { context.append(")"); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java index ee4e98a6bc3..025f8e600fe 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java @@ -1,30 +1,31 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; import java.util.List; import java.util.Optional; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.lang.Nullable; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @since 1.0.0 @@ -53,7 +54,7 @@ public void add(List documents) { VectorStoreObservationDocumentation.AI_VECTOR_STORE .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - observationRegistry) + this.observationRegistry) .observe(() -> this.doAdd(documents)); } @@ -96,4 +97,4 @@ public List similaritySearch(SearchRequest request) { public abstract VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java index cfddd211c62..15700d35723 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java @@ -1,29 +1,30 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; +import io.micrometer.common.KeyValue; +import io.micrometer.common.KeyValues; + import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import io.micrometer.common.KeyValue; -import io.micrometer.common.KeyValues; - /** * Default conventions to populate observations for vector store operations. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java index cb834f586f8..8513f81d6d9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; +import java.util.List; + import org.springframework.ai.document.Document; import org.springframework.util.CollectionUtils; -import java.util.List; - /** * Utilities to process the query content in observations for vector store operations. * @@ -27,6 +28,9 @@ */ public final class VectorStoreObservationContentProcessor { + private VectorStoreObservationContentProcessor() { + } + public static List documents(VectorStoreObservationContext context) { if (CollectionUtils.isEmpty(context.getQueryResponse())) { return List.of(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java index d12dd55adc1..07da0990da1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java @@ -1,29 +1,30 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; import java.util.List; +import io.micrometer.observation.Observation; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import io.micrometer.observation.Observation; - /** * Context used to store metadata for vector store operations. * @@ -33,37 +34,10 @@ */ public class VectorStoreObservationContext extends Observation.Context { - public enum Operation { - - /** - * VectorStore delete operation. - */ - ADD("add"), - /** - * VectorStore add operation. - */ - DELETE("delete"), - /** - * VectorStore similarity search operation. - */ - QUERY("query"); - - public final String value; - - Operation(String value) { - this.value = value; - } - - public String value() { - return this.value; - } - - } + private final String databaseSystem; // COMMON - private final String databaseSystem; - private final String operationName; @Nullable @@ -81,11 +55,11 @@ public String value() { @Nullable private String similarityMetric; - // SEARCH - @Nullable private SearchRequest queryRequest; + // SEARCH + @Nullable private List queryResponse; @@ -96,6 +70,14 @@ public VectorStoreObservationContext(String databaseSystem, String operationName this.operationName = operationName; } + public static Builder builder(String databaseSystem, String operationName) { + return new Builder(databaseSystem, operationName); + } + + public static Builder builder(String databaseSystem, Operation operation) { + return builder(databaseSystem, operation.value); + } + public String getDatabaseSystem() { return this.databaseSystem; } @@ -106,7 +88,7 @@ public String getOperationName() { @Nullable public String getCollectionName() { - return collectionName; + return this.collectionName; } public void setCollectionName(@Nullable String collectionName) { @@ -115,7 +97,7 @@ public void setCollectionName(@Nullable String collectionName) { @Nullable public Integer getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(@Nullable Integer dimensions) { @@ -124,7 +106,7 @@ public void setDimensions(@Nullable Integer dimensions) { @Nullable public String getFieldName() { - return fieldName; + return this.fieldName; } public void setFieldName(@Nullable String fieldName) { @@ -133,7 +115,7 @@ public void setFieldName(@Nullable String fieldName) { @Nullable public String getNamespace() { - return namespace; + return this.namespace; } public void setNamespace(@Nullable String namespace) { @@ -142,7 +124,7 @@ public void setNamespace(@Nullable String namespace) { @Nullable public String getSimilarityMetric() { - return similarityMetric; + return this.similarityMetric; } public void setSimilarityMetric(@Nullable String similarityMetric) { @@ -151,7 +133,7 @@ public void setSimilarityMetric(@Nullable String similarityMetric) { @Nullable public SearchRequest getQueryRequest() { - return queryRequest; + return this.queryRequest; } public void setQueryRequest(@Nullable SearchRequest queryRequest) { @@ -160,19 +142,38 @@ public void setQueryRequest(@Nullable SearchRequest queryRequest) { @Nullable public List getQueryResponse() { - return queryResponse; + return this.queryResponse; } public void setQueryResponse(@Nullable List queryResponse) { this.queryResponse = queryResponse; } - public static Builder builder(String databaseSystem, String operationName) { - return new Builder(databaseSystem, operationName); - } + public enum Operation { + + /** + * VectorStore delete operation. + */ + ADD("add"), + /** + * VectorStore add operation. + */ + DELETE("delete"), + /** + * VectorStore similarity search operation. + */ + QUERY("query"); + + public final String value; + + Operation(String value) { + this.value = value; + } + + public String value() { + return this.value; + } - public static Builder builder(String databaseSystem, Operation operation) { - return builder(databaseSystem, operation.value); } public static class Builder { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java index 9bf80d838d2..38a64d3771a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; import io.micrometer.observation.Observation; @@ -30,4 +31,4 @@ default boolean supportsContext(Observation.Context context) { return context instanceof VectorStoreObservationContext; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java index f56ead4def5..f351ca29253 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java @@ -1,27 +1,28 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.vectorstore.observation; + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; +package org.springframework.ai.vectorstore.observation; import io.micrometer.common.docs.KeyName; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationConvention; import io.micrometer.observation.docs.ObservationDocumentation; +import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; + /** * Documented conventions for vector store observations. * @@ -85,7 +86,7 @@ public String asString() { public String asString() { return VectorStoreObservationAttributes.DB_SYSTEM.value(); } - }; + } } @@ -200,7 +201,7 @@ public String asString() { public String asString() { return VectorStoreObservationAttributes.DB_VECTOR_QUERY_TOP_K.value(); } - }; + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java index 4beab2b4f6a..a601acc3bb9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.observation; -import org.springframework.ai.observation.tracing.TracingHelper; +package org.springframework.ai.vectorstore.observation; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationFilter; + +import org.springframework.ai.observation.tracing.TracingHelper; import org.springframework.util.CollectionUtils; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java index 1e46710d714..9dbbefc8cab 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; import io.micrometer.observation.Observation; @@ -21,6 +22,7 @@ import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; + import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; import org.springframework.ai.observation.conventions.VectorStoreObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java index a7e006093ab..0fd62c25bf3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.vectorstore.observation; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java b/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java index 023cfa6a1d9..8971ef22dc8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.writer; import java.io.FileWriter; diff --git a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties index b6fb61c96f3..85a5447d308 100644 --- a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties +++ b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties @@ -1,3 +1,18 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Map of embedding generative names and their dimensions text-embedding-ada-002=1536 text-similarity-ada-001=1024 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java b/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java index 203ac200ff3..582552998d0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import org.springframework.boot.SpringBootConfiguration; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java index 02addaecfeb..97df43159ed 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import java.util.Set; @@ -28,16 +29,20 @@ class AiRuntimeHintsTests { + @Test + void discoverRelevantClasses() throws Exception { + var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class); + var included = Set.of(TestApi.Bar.class, TestApi.Foo.class) + .stream() + .map(t -> TypeReference.of(t.getName())) + .collect(Collectors.toSet()); + LogFactory.getLog(getClass()).info(classes); + Assert.state(classes.containsAll(included), "there should be all of the enumerated classes. "); + } + @JsonInclude static class TestApi { - static class FooBar { - - } - - record Foo(@JsonProperty("name") String name) { - } - @JsonInclude enum Bar { @@ -45,17 +50,14 @@ enum Bar { } - } + static class FooBar { + + } + + record Foo(@JsonProperty("name") String name) { + + } - @Test - void discoverRelevantClasses() throws Exception { - var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class); - var included = Set.of(TestApi.Bar.class, TestApi.Foo.class) - .stream() - .map(t -> TypeReference.of(t.getName())) - .collect(Collectors.toSet()); - LogFactory.getLog(getClass()).info(classes); - Assert.state(classes.containsAll(included), "there should be all of the enumerated classes. "); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java index eb45821b8d5..409c2329ed3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import org.junit.jupiter.api.Test; + import org.springframework.aot.hint.RuntimeHints; import static org.assertj.core.api.Assertions.assertThat; @@ -31,4 +33,4 @@ void knuddels() { assertThat(runtimeHints).matches(resource().forResource("com/knuddels/jtokkit/cl100k_base.tiktoken")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java index 372d6f076ae..b379c9bb159 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import org.junit.jupiter.api.Test; @@ -38,4 +39,4 @@ void core() { assertThat(runtimeHints).matches(reflection().onMethod(FunctionCallback.class, "getName")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java index 5ed9ccc8f06..7ab6abb255b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.ChatOptions; @@ -30,6 +30,8 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.model.function.FunctionCallingOptions; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link Prompt}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java index f3be7deafd8..27568b6be17 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,28 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -import org.junit.jupiter.api.Test; - -import org.mockito.Mockito; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; /** * Unit Tests for {@link ChatModel}. @@ -53,15 +53,15 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { ChatModel mockClient = Mockito.mock(ChatModel.class); AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); - when(mockAssistantMessage.getContent()).thenReturn(responseMessage); + given(mockAssistantMessage.getContent()).willReturn(responseMessage); // Create a mock Generation Generation generation = Mockito.mock(Generation.class); - when(generation.getOutput()).thenReturn(mockAssistantMessage); + given(generation.getOutput()).willReturn(mockAssistantMessage); // Create a mock ChatResponse with the mock Generation ChatResponse response = Mockito.mock(ChatResponse.class); - when(response.getResult()).thenReturn(generation); + given(response.getResult()).willReturn(generation); // Generation generation = spy(new Generation(responseMessage)); // ChatResponse response = spy(new @@ -69,16 +69,14 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { doCallRealMethod().when(mockClient).call(anyString()); - doAnswer(invocationOnMock -> { - + given(mockClient.call(any(Prompt.class))).willAnswer(invocationOnMock -> { Prompt prompt = invocationOnMock.getArgument(0); assertThat(prompt).isNotNull(); assertThat(prompt.getContents()).isEqualTo(userMessage); return response; - - }).when(mockClient).call(any(Prompt.class)); + }); assertThat(mockClient.call(userMessage)).isEqualTo(responseMessage); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java index 251dd184ecb..07d77ecb1cd 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.util.List; import java.util.stream.Collectors; @@ -28,6 +25,8 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemory; @@ -35,14 +34,14 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.prompt.Prompt; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -71,15 +70,15 @@ public void promptChatMemory() { .withKeyValue("system-fingerprint", "john doe"); ChatResponseMetadata chatResponseMetadata = builder.build(); - when(chatModel.call(promptCaptor.capture())) - .thenReturn( + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn( new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))), chatResponseMetadata)) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))), + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))), chatResponseMetadata)); ChatMemory chatMemory = new InMemoryChatMemory(); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) .build(); @@ -89,7 +88,7 @@ public void promptChatMemory() { String content = chatResponse.getResult().getOutput().getContent(); assertThat(content).isEqualTo("Hello John"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -101,14 +100,14 @@ public void promptChatMemory() { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - Message userMessage = promptCaptor.getValue().getInstructions().get(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("my name is John"); content = chatClient.prompt().user("What is my name?").call().content(); assertThat(content).isEqualTo("Your name is John"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -122,20 +121,20 @@ public void promptChatMemory() { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - userMessage = promptCaptor.getValue().getInstructions().get(1); + userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("What is my name?"); } @Test public void streamingPromptChatMemory() { - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })) - .thenReturn(Flux.generate( + .willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John")))), (state, sink) -> { sink.next(state); @@ -145,7 +144,7 @@ public void streamingPromptChatMemory() { ChatMemory chatMemory = new InMemoryChatMemory(); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) .build(); @@ -154,7 +153,7 @@ public void streamingPromptChatMemory() { assertThat(content).isEqualTo("Hello John"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -166,14 +165,14 @@ public void streamingPromptChatMemory() { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - Message userMessage = promptCaptor.getValue().getInstructions().get(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("my name is John"); content = join(chatClient.prompt().user("What is my name?").stream().content()); assertThat(content).isEqualTo("Your name is John"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -187,7 +186,7 @@ public void streamingPromptChatMemory() { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - userMessage = promptCaptor.getValue().getInstructions().get(1); + userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("What is my name?"); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java index 2e40f9def72..5295a213acf 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ import org.springframework.core.ParameterizedTypeReference; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -51,9 +51,6 @@ public class ChatClientResponseEntityTests { @Captor ArgumentCaptor promptCaptor; - record MyBean(String name, int age) { - } - @Test public void responseEntityTest() { @@ -63,9 +60,9 @@ public void responseEntityTest() { {"name":"John", "age":30} """)), metadata); - when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse); - ResponseEntity responseEntity = ChatClient.builder(chatModel) + ResponseEntity responseEntity = ChatClient.builder(this.chatModel) .build() .prompt() .user("Tell me about John") @@ -77,7 +74,7 @@ public void responseEntityTest() { assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30)); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).contains("Tell me about John"); } @@ -87,26 +84,27 @@ public void parametrizedResponseEntityTest() { var chatResponse = new ChatResponse(List.of(new Generation(""" [ - {"name":"Max", "age":10}, - {"name":"Adi", "age":13} + {"name":"Max", "age":10}, + {"name":"Adi", "age":13} ] """))); - when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse); - ResponseEntity> responseEntity = ChatClient.builder(chatModel) + ResponseEntity> responseEntity = ChatClient.builder(this.chatModel) .build() .prompt() .user("Tell me about them") .call() .responseEntity(new ParameterizedTypeReference>() { + }); assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10)); assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13)); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).contains("Tell me about them"); } @@ -115,12 +113,12 @@ public void parametrizedResponseEntityTest() { public void customSoCResponseEntityTest() { var chatResponse = new ChatResponse(List.of(new Generation(""" - {"name":"Max", "age":10}, + {"name":"Max", "age":10}, """))); - when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse); - ResponseEntity> responseEntity = ChatClient.builder(chatModel) + ResponseEntity> responseEntity = ChatClient.builder(this.chatModel) .build() .prompt() .user("Tell me about Max") @@ -131,9 +129,13 @@ public void customSoCResponseEntityTest() { assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max"); assertThat(responseEntity.getEntity().get("age")).isEqualTo(10); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).contains("Tell me about Max"); } + record MyBean(String name, int age) { + + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 7033cdd5a55..3dc853b2fa6 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -54,6 +54,14 @@ @ExtendWith(MockitoExtension.class) public class ChatClientTest { + static Function mockFunction = new Function() { + + @Override + public String apply(String s) { + return s; + } + }; + @Mock ChatModel chatModel; @@ -68,23 +76,23 @@ private String join(Flux fluxContent) { @Test public void defaultSystemText() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })); - var chatClient = ChatClient.builder(chatModel).defaultSystem("Default system text").build(); + var chatClient = ChatClient.builder(this.chatModel).defaultSystem("Default system text").build(); var content = chatClient.prompt().call().content(); assertThat(content).isEqualTo("response"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -92,7 +100,7 @@ public void defaultSystemText() { assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -100,7 +108,7 @@ public void defaultSystemText() { content = chatClient.prompt().system("Override default system text").call().content(); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -108,7 +116,7 @@ public void defaultSystemText() { content = join(chatClient.prompt().system("Override default system text").stream().content()); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @@ -116,17 +124,17 @@ public void defaultSystemText() { @Test public void defaultSystemTextLambda() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) @@ -136,7 +144,7 @@ public void defaultSystemTextLambda() { assertThat(content).isEqualTo("response"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -145,7 +153,7 @@ public void defaultSystemTextLambda() { assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -153,7 +161,7 @@ public void defaultSystemTextLambda() { content = chatClient.prompt().system(s -> s.param("param1", "value1New")).call().content(); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1New, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -161,7 +169,7 @@ public void defaultSystemTextLambda() { content = join(chatClient.prompt().system(s -> s.param("param1", "value1New")).stream().content()); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1New, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -172,7 +180,7 @@ public void defaultSystemTextLambda() { .content(); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text value3"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -183,28 +191,21 @@ public void defaultSystemTextLambda() { .content()); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text value3"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } - static Function mockFunction = new Function() { - @Override - public String apply(String s) { - return s; - } - }; - @Test public void mutateDefaults() { PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); - when(chatModel.getDefaultOptions()).thenReturn(options); + given(this.chatModel.getDefaultOptions()).willReturn(options); - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); @@ -212,7 +213,7 @@ public void mutateDefaults() { })); // @formatter:off - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) @@ -230,7 +231,7 @@ public void mutateDefaults() { assertThat(content).isEqualTo("response"); - Prompt prompt = promptCaptor.getValue(); + Prompt prompt = this.promptCaptor.getValue(); Message systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -252,7 +253,7 @@ public void mutateDefaults() { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -282,7 +283,7 @@ public void mutateDefaults() { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -304,7 +305,7 @@ public void mutateDefaults() { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -327,19 +328,19 @@ public void mutateDefaults() { public void mutatePrompt() { PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); - when(chatModel.getDefaultOptions()).thenReturn(options); + given(this.chatModel.getDefaultOptions()).willReturn(options); - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })); // @formatter:off - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) @@ -364,7 +365,7 @@ public void mutatePrompt() { assertThat(content).isEqualTo("response"); - Prompt prompt = promptCaptor.getValue(); + Prompt prompt = this.promptCaptor.getValue(); Message systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -395,7 +396,7 @@ public void mutatePrompt() { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -416,16 +417,16 @@ public void mutatePrompt() { @Test public void defaultUserText() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - var chatClient = ChatClient.builder(chatModel).defaultUser("Default user text").build(); + var chatClient = ChatClient.builder(this.chatModel).defaultUser("Default user text").build(); var content = chatClient.prompt().call().content(); assertThat(content).isEqualTo("response"); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("Default user text"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); @@ -433,50 +434,51 @@ public void defaultUserText() { content = chatClient.prompt().user("Override default user text").call().content(); assertThat(content).isEqualTo("response"); - userMessage = promptCaptor.getValue().getInstructions().get(0); + userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("Override default user text"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test public void simpleUserPromptAsString() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - assertThat(ChatClient.builder(chatModel).build().prompt("User prompt").call().content()).isEqualTo("response"); + assertThat(ChatClient.builder(this.chatModel).build().prompt("User prompt").call().content()) + .isEqualTo("response"); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("User prompt"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test public void simpleUserPrompt() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - assertThat(ChatClient.builder(chatModel).build().prompt().user("User prompt").call().content()) + assertThat(ChatClient.builder(this.chatModel).build().prompt().user("User prompt").call().content()) .isEqualTo("response"); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("User prompt"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test public void simpleUserPromptObject() throws MalformedURLException { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); var media = new Media(MimeTypeUtils.IMAGE_JPEG, new DefaultResourceLoader().getResource("classpath:/bikes.json")); UserMessage message = new UserMessage("User prompt", List.of(media)); Prompt prompt = new Prompt(message); - assertThat(ChatClient.builder(chatModel).build().prompt(prompt).call().content()).isEqualTo("response"); + assertThat(ChatClient.builder(this.chatModel).build().prompt(prompt).call().content()).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(1); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).isEqualTo("User prompt"); assertThat(((UserMessage) userMessage).getMedia()).hasSize(1); @@ -484,32 +486,32 @@ public void simpleUserPromptObject() throws MalformedURLException { @Test public void simpleSystemPrompt() throws MalformedURLException { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - String response = ChatClient.builder(chatModel).build().prompt().system("System prompt").call().content(); + String response = ChatClient.builder(this.chatModel).build().prompt().system("System prompt").call().content(); assertThat(response).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(1); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System prompt"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @Test public void complexCall() throws MalformedURLException { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); var options = FunctionCallingOptions.builder().build(); - when(chatModel.getDefaultOptions()).thenReturn(options); + given(this.chatModel.getDefaultOptions()).willReturn(options); var url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - ChatClient client = ChatClient.builder(chatModel) + ChatClient client = ChatClient.builder(this.chatModel) .defaultSystem("System text") .defaultFunctions("function1") .build(); @@ -521,13 +523,13 @@ public void complexCall() throws MalformedURLException { // @formatter:on assertThat(response).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(2); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - UserMessage userMessage = (UserMessage) promptCaptor.getValue().getInstructions().get(1); + UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualTo("User text Rock"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getMedia()).hasSize(1); @@ -535,7 +537,7 @@ public void complexCall() throws MalformedURLException { assertThat(userMessage.getMedia().iterator().next().getData()) .isEqualTo("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); - FunctionCallingOptions runtieOptions = (FunctionCallingOptions) promptCaptor.getValue().getOptions(); + FunctionCallingOptions runtieOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); assertThat(runtieOptions.getFunctions()).containsExactly("function1"); assertThat(options.getFunctions()).isEmpty(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java index 9a8abcfce70..21a23ad4d3a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,6 @@ package org.springframework.ai.chat.client.advisor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; -import static org.mockito.Mockito.verify; - import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -31,11 +27,13 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; @@ -44,7 +42,9 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.verify; /** * @author Christian Tzolov @@ -58,76 +58,6 @@ public class AdvisorsTests { @Captor ArgumentCaptor promptCaptor; - public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { - - public AdvisedRequest advisedRequest; - - public AdvisedResponse advisedResponse; - - public List aroundAdvisedResponses = new ArrayList<>(); - - private final String name; - - private final int order; - - public MockAroundAdvisor(String name, int order) { - this.name = name; - this.order = order; - } - - @Override - public String getName() { - return name; - } - - @Override - public int getOrder() { - return order; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - - this.advisedRequest = advisedRequest.updateContext(context -> { - context.put("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName()); - context.put("lastBefore", getName()); - return context; - }); - - AdvisedResponse advisedResponse = this.advisedResponse = chain.nextAroundCall(this.advisedRequest); - - this.advisedResponse = advisedResponse.updateContext(context -> { - context.put("aroundCallAfter" + name, "AROUND_CALL_AFTER " + name); - context.put("lastAfter", name); - return context; - }); - - return this.advisedResponse; - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - this.advisedRequest = advisedRequest.updateContext(context -> { - context.put("aroundStreamBefore" + name, "AROUND_STREAM_BEFORE " + name); - context.put("lastBefore", name); - return context; - }); - - Flux advisedResponseStream = chain.nextAroundStream(this.advisedRequest); - - return advisedResponseStream.map(advisedResponse -> { - return advisedResponse.updateContext(context -> { - context.put("aroundStreamAfter" + name, "AROUND_STREAM_AFTER " + name); - context.put("lastAfter", name); - return context; - }); - }).doOnNext(ar -> this.aroundAdvisedResponses.add(ar)); - - } - - } - @Test public void callAdvisorsContextPropagation() { @@ -136,10 +66,10 @@ public void callAdvisorsContextPropagation() { var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(mockAroundAdvisor1) .build(); @@ -164,7 +94,7 @@ public void callAdvisorsContextPropagation() { .containsEntry("lastBefore", "Advisor2") // inner .containsEntry("lastAfter", "Advisor1"); // outer - verify(chatModel).call(promptCaptor.capture()); + verify(this.chatModel).call(this.promptCaptor.capture()); } @Test @@ -173,11 +103,11 @@ public void streamAdvisorsContextPropagation() { var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); - when(chatModel.stream(promptCaptor.capture())) - .thenReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), + given(this.chatModel.stream(this.promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), new ChatResponse(List.of(new Generation(new AssistantMessage(" John")))))); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(mockAroundAdvisor1) .build(); @@ -209,7 +139,78 @@ public void streamAdvisorsContextPropagation() { .containsEntry("lastAfter", "Advisor1"); // outer }); - verify(chatModel).stream(promptCaptor.capture()); + verify(this.chatModel).stream(this.promptCaptor.capture()); + } + + public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private final String name; + + private final int order; + + public AdvisedRequest advisedRequest; + + public AdvisedResponse advisedResponse; + + public List aroundAdvisedResponses = new ArrayList<>(); + + public MockAroundAdvisor(String name, int order) { + this.name = name; + this.order = order; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.updateContext(context -> { + context.put("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName()); + context.put("lastBefore", getName()); + return context; + }); + + this.advisedResponse = chain.nextAroundCall(this.advisedRequest); + AdvisedResponse advisedResponse = this.advisedResponse; + + this.advisedResponse = advisedResponse.updateContext(context -> { + context.put("aroundCallAfter" + this.name, "AROUND_CALL_AFTER " + this.name); + context.put("lastAfter", this.name); + return context; + }); + + return this.advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.updateContext(context -> { + context.put("aroundStreamBefore" + this.name, "AROUND_STREAM_BEFORE " + this.name); + context.put("lastBefore", this.name); + return context; + }); + + Flux advisedResponseStream = chain.nextAroundStream(this.advisedRequest); + + return advisedResponseStream.map(advisedResponse -> { + return advisedResponse.updateContext(context -> { + context.put("aroundStreamAfter" + this.name, "AROUND_STREAM_AFTER " + this.name); + context.put("lastAfter", this.name); + return context; + }); + }).doOnNext(ar -> this.aroundAdvisedResponses.add(ar)); + + } + } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java index e75b1d04cf6..c63fa9ec206 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client.advisor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.time.Duration; import java.util.List; import java.util.Map; @@ -29,6 +26,7 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -45,6 +43,9 @@ import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + /** * @author Christian Tzolov */ @@ -67,24 +68,24 @@ public class QuestionAnswerAdvisorTests { public void qaAdvisorWithDynamicFilterExpressions() { // @formatter:off - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), ChatResponseMetadata.builder() .withId("678") .withModel("model1") .withKeyValue("key6", "value6") - .withMetadata(Map.of("key1","value1" )) + .withMetadata(Map.of("key1", "value1")) .withPromptMetadata(null) .withRateLimit(new RateLimit() { @Override public Long getRequestsLimit() { - return 5l; + return 5L; } @Override public Long getRequestsRemaining() { - return 6l; + return 6L; } @Override @@ -94,12 +95,12 @@ public Duration getRequestsReset() { @Override public Long getTokensLimit() { - return 8l; + return 8L; } @Override public Long getTokensRemaining() { - return 8l; + return 8L; } @Override @@ -107,17 +108,17 @@ public Duration getTokensReset() { return Duration.ofSeconds(9); } }) - .withUsage(new DefaultUsage(6l, 7l)) + .withUsage(new DefaultUsage(6L, 7L)) .build())); // @formatter:on - when(vectorStore.similaritySearch(vectorSearchCaptor.capture())) - .thenReturn(List.of(new Document("doc1"), new Document("doc2"))); + given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture())) + .willReturn(List.of(new Document("doc1"), new Document("doc2"))); - var qaAdvisor = new QuestionAnswerAdvisor(vectorStore, + var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.defaults().withSimilarityThreshold(0.99d).withTopK(6)); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(qaAdvisor) .build(); @@ -133,26 +134,23 @@ public Duration getTokensReset() { // Ensure the metadata is correctly copied over assertThat(response.getMetadata().getModel()).isEqualTo("model1"); assertThat(response.getMetadata().getId()).isEqualTo("678"); - assertThat(response.getMetadata().getRateLimit().getRequestsLimit()).isEqualTo(5l); - assertThat(response.getMetadata().getRateLimit().getRequestsRemaining()).isEqualTo(6l); + assertThat(response.getMetadata().getRateLimit().getRequestsLimit()).isEqualTo(5L); + assertThat(response.getMetadata().getRateLimit().getRequestsRemaining()).isEqualTo(6L); assertThat(response.getMetadata().getRateLimit().getRequestsReset()).isEqualTo(Duration.ofSeconds(7)); - assertThat(response.getMetadata().getRateLimit().getTokensLimit()).isEqualTo(8l); - assertThat(response.getMetadata().getRateLimit().getTokensRemaining()).isEqualTo(8l); + assertThat(response.getMetadata().getRateLimit().getTokensLimit()).isEqualTo(8L); + assertThat(response.getMetadata().getRateLimit().getTokensRemaining()).isEqualTo(8L); assertThat(response.getMetadata().getRateLimit().getTokensReset()).isEqualTo(Duration.ofSeconds(9)); - assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6l); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isEqualTo(7l); - assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6l + 7l); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6L); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isEqualTo(7L); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6L + 7L); assertThat(response.getMetadata().get("key6").toString()).isEqualTo("value6"); assertThat(response.getMetadata().get("key1").toString()).isEqualTo("value1"); - - - String content = response.getResult().getOutput().getContent(); assertThat(content).isEqualTo("Your answer is ZXY"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); System.out.println(systemMessage.getContent()); @@ -161,7 +159,7 @@ public Duration getTokensReset() { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - Message userMessage = promptCaptor.getValue().getInstructions().get(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace(""" Please answer my question XYZ @@ -177,9 +175,9 @@ public Duration getTokensReset() { the user that you can't answer the question. """); - assertThat(vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build()); - assertThat(vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d); - assertThat(vectorSearchCaptor.getValue().getTopK()).isEqualTo(6); + assertThat(this.vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build()); + assertThat(this.vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d); + assertThat(this.vectorSearchCaptor.getValue().getTopK()).isEqualTo(6); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java index ee864a22d70..b12e1b24f3a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client.advisor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.util.List; import java.util.stream.Collectors; @@ -28,6 +25,8 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -39,7 +38,8 @@ import org.springframework.boot.test.system.OutputCaptureExtension; import org.springframework.test.context.ActiveProfiles; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -57,12 +57,12 @@ public class SimpleLoggerAdvisorTests { @Test public void callLogging(CapturedOutput output) { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))))); var loggerAdvisor = new SimpleLoggerAdvisor(); - var chatClient = ChatClient.builder(chatModel).defaultAdvisors(loggerAdvisor).build(); + var chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(loggerAdvisor).build(); var content = chatClient.prompt().user("Please answer my question XYZ").call().content(); @@ -72,7 +72,7 @@ public void callLogging(CapturedOutput output) { @Test public void streamLogging(CapturedOutput output) { - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY")))), (state, sink) -> { sink.next(state); @@ -82,7 +82,7 @@ public void streamLogging(CapturedOutput output) { var loggerAdvisor = new SimpleLoggerAdvisor(); - var chatClient = ChatClient.builder(chatModel).defaultAdvisors(loggerAdvisor).build(); + var chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(loggerAdvisor).build(); String content = join(chatClient.prompt().user("Please answer my question XYZ").stream().content()); @@ -100,7 +100,7 @@ public void loggingOrder() { private void validate(String content, CapturedOutput output) { assertThat(content).isEqualTo("Your answer is ZXY"); - UserMessage userMessage = (UserMessage) promptCaptor.getValue().getInstructions().get(0); + UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("Please answer my question XYZ"); assertThat(output.getOut()).contains("request: AdvisedRequest", "userText=Please answer my question XYZ"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java index 795d1056faa..9e12573fad0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor.observation; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import org.junit.jupiter.api.Test; - /** * Unit tests for {@link AdvisorObservationContext}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java index 7f9f4e3da02..906a3376873 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.advisor.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.advisor.observation; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.LowCardinalityKeyNames; - -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; import org.springframework.ai.observation.conventions.SpringAiKind; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit tests for {@link DefaultAdvisorObservationConvention}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java index 2b189fdf797..31d017d749e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.observation; import java.util.List; import java.util.Map; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.model.ChatModel; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link ChatClientInputContentObservationFilter}. @@ -43,29 +44,29 @@ class ChatClientInputContentObservationFilterTests { private final ChatClientInputContentObservationFilter observationFilter = new ChatClientInputContentObservationFilter(); + @Mock + ChatModel chatModel; + @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } - @Mock - ChatModel chatModel; - @Test void whenEmptyInputContentThenReturnOriginalContext() { ObservationRegistry observationRegistry = ObservationRegistry.NOOP; ChatClientObservationConvention customObservationConvention = null; - var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), + var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); var expectedContext = ChatClientObservationContext.builder().withRequest(request).build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -75,13 +76,13 @@ void whenWithTextThenAugmentContext() { ObservationRegistry observationRegistry = ObservationRegistry.NOOP; ChatClientObservationConvention customObservationConvention = null; - var request = new DefaultChatClientRequestSpec(chatModel, "sample user text", Map.of("up1", "upv1"), + var request = new DefaultChatClientRequestSpec(this.chatModel, "sample user text", Map.of("up1", "upv1"), "sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); var originalContext = ChatClientObservationContext.builder().withRequest(request).build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT.asString(), "sample user text")); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java index 0cc401e8735..cf8f644248a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.observation; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.model.ChatModel; -import io.micrometer.observation.ObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link ChatClientObservationContext}. @@ -44,7 +45,7 @@ class ChatClientObservationContextTests { @Test void whenMandatoryRequestOptionsThenReturn() { - var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), + var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 5809cc19dd3..0f1e4814277 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.observation; import java.util.List; import java.util.Map; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; @@ -36,9 +39,7 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.observation.conventions.SpringAiKind; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link DefaultChatClientObservationConvention}. @@ -49,60 +50,16 @@ @ExtendWith(MockitoExtension.class) class DefaultChatClientObservationConventionTests { + private final DefaultChatClientObservationConvention observationConvention = new DefaultChatClientObservationConvention(); + @Mock ChatModel chatModel; - private final DefaultChatClientObservationConvention observationConvention = new DefaultChatClientObservationConvention(); - DefaultChatClientRequestSpec request; - @BeforeEach - public void beforeEach() { - request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); - } - - @Test - void shouldHaveName() { - assertThat(this.observationConvention.getName()).isEqualTo(DefaultChatClientObservationConvention.DEFAULT_NAME); - } - - @Test - void shouldHaveContextualName() { - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) - .withStream(true) - .build(); - - assertThat(this.observationConvention.getContextualName(observationContext)) - .isEqualTo("%s %s".formatted(AiProvider.SPRING_AI.value(), SpringAiKind.CHAT_CLIENT.value())); - } - - @Test - void supportsOnlyChatClientObservationContext() { - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) - .withStream(true) - .build(); - - assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); - assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); - } - - @Test - void shouldHaveRequiredKeyValues() { - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) - .withStream(true) - .build(); - - assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( - KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat_client"), - KeyValue.of(LowCardinalityKeyNames.STREAM.asString(), "true")); - } - static RequestResponseAdvisor dummyAdvisor(String name) { return new RequestResponseAdvisor() { + @Override public String getName() { return name; @@ -128,6 +85,7 @@ public ChatResponse adviseResponse(ChatResponse response, Map ad static FunctionCallback dummyFunction(String name) { return new FunctionCallback() { + @Override public String getName() { return name; @@ -153,9 +111,54 @@ public String call(String functionInput) { }; } + @BeforeEach + public void beforeEach() { + this.request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), + List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); + } + + @Test + void shouldHaveName() { + assertThat(this.observationConvention.getName()).isEqualTo(DefaultChatClientObservationConvention.DEFAULT_NAME); + } + + @Test + void shouldHaveContextualName() { + ChatClientObservationContext observationContext = ChatClientObservationContext.builder() + .withRequest(this.request) + .withStream(true) + .build(); + + assertThat(this.observationConvention.getContextualName(observationContext)) + .isEqualTo("%s %s".formatted(AiProvider.SPRING_AI.value(), SpringAiKind.CHAT_CLIENT.value())); + } + + @Test + void supportsOnlyChatClientObservationContext() { + ChatClientObservationContext observationContext = ChatClientObservationContext.builder() + .withRequest(this.request) + .withStream(true) + .build(); + + assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); + assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); + } + + @Test + void shouldHaveRequiredKeyValues() { + ChatClientObservationContext observationContext = ChatClientObservationContext.builder() + .withRequest(this.request) + .withStream(true) + .build(); + + assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( + KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat_client"), + KeyValue.of(LowCardinalityKeyNames.STREAM.asString(), "true")); + } + @Test void shouldHaveOptionalKeyValues() { - var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), + var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(), List.of("function1", "function2"), List.of(), null, List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java index 68985f67bd6..8059faf9580 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; public class DefaultUsageTests { @@ -26,14 +28,14 @@ public class DefaultUsageTests { @Test void testSerializationWithAllFields() throws Exception { DefaultUsage usage = new DefaultUsage(100L, 50L, 150L); - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json); } @Test void testDeserializationWithAllFields() throws Exception { String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}"; - DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(100L, usage.getPromptTokens()); assertEquals(50L, usage.getGenerationTokens()); assertEquals(150L, usage.getTotalTokens()); @@ -42,14 +44,14 @@ void testDeserializationWithAllFields() throws Exception { @Test void testSerializationWithNullFields() throws Exception { DefaultUsage usage = new DefaultUsage(null, null, null); - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json); } @Test void testDeserializationWithMissingFields() throws Exception { String json = "{\"promptTokens\":100}"; - DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(100L, usage.getPromptTokens()); assertEquals(0L, usage.getGenerationTokens()); assertEquals(100L, usage.getTotalTokens()); @@ -58,7 +60,7 @@ void testDeserializationWithMissingFields() throws Exception { @Test void testDeserializationWithNullFields() throws Exception { String json = "{\"promptTokens\":null,\"generationTokens\":null,\"totalTokens\":null}"; - DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(0L, usage.getPromptTokens()); assertEquals(0L, usage.getGenerationTokens()); assertEquals(0L, usage.getTotalTokens()); @@ -67,8 +69,8 @@ void testDeserializationWithNullFields() throws Exception { @Test void testRoundTripSerialization() throws Exception { DefaultUsage original = new DefaultUsage(100L, 50L, 150L); - String json = objectMapper.writeValueAsString(original); - DefaultUsage deserialized = objectMapper.readValue(json, DefaultUsage.class); + String json = this.objectMapper.writeValueAsString(original); + DefaultUsage deserialized = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(original.getPromptTokens(), deserialized.getPromptTokens()); assertEquals(original.getGenerationTokens(), deserialized.getGenerationTokens()); assertEquals(original.getTotalTokens(), deserialized.getTotalTokens()); @@ -84,11 +86,11 @@ void testTwoArgumentConstructorAndSerialization() throws Exception { assertEquals(150L, usage.getTotalTokens()); // 100 + 50 = 150 // Test serialization - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json); // Test deserialization - DefaultUsage deserializedUsage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(100L, deserializedUsage.getPromptTokens()); assertEquals(50L, deserializedUsage.getGenerationTokens()); assertEquals(150L, deserializedUsage.getTotalTokens()); @@ -104,11 +106,11 @@ void testTwoArgumentConstructorWithNullValues() throws Exception { assertEquals(0L, usage.getTotalTokens()); // Test serialization - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json); // Test deserialization - DefaultUsage deserializedUsage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(0L, deserializedUsage.getPromptTokens()); assertEquals(0L, deserializedUsage.getGenerationTokens()); assertEquals(0L, deserializedUsage.getTotalTokens()); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java index 4bcf3344eec..b5e173e1d4b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java @@ -1,9 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.model; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; @@ -41,9 +58,9 @@ void testGetOutput() { @Test void testConstructorWithMetadata() { AssistantMessage assistantMessage = new AssistantMessage("Test Assistant Message"); - Generation generation = new Generation(assistantMessage, mockChatGenerationMetadata1); + Generation generation = new Generation(assistantMessage, this.mockChatGenerationMetadata1); - assertEquals(mockChatGenerationMetadata1, generation.getMetadata()); + assertEquals(this.mockChatGenerationMetadata1, generation.getMetadata()); } @Test @@ -58,10 +75,10 @@ void testGetMetadata_Null() { @Test void testGetMetadata_NotNull() { AssistantMessage assistantMessage = new AssistantMessage("Test Assistant Message"); - Generation generation = new Generation(assistantMessage, mockChatGenerationMetadata1); + Generation generation = new Generation(assistantMessage, this.mockChatGenerationMetadata1); ChatGenerationMetadata metadata = generation.getMetadata(); - assertEquals(mockChatGenerationMetadata1, metadata); + assertEquals(this.mockChatGenerationMetadata1, metadata); } @Test @@ -86,8 +103,8 @@ void testEquals_NotInstanceOfGeneration() { void testEquals_SameMetadata() { AssistantMessage assistantMessage1 = new AssistantMessage("Test Assistant Message"); AssistantMessage assistantMessage2 = new AssistantMessage("Test Assistant Message"); - Generation generation1 = new Generation(assistantMessage1, mockChatGenerationMetadata1); - Generation generation2 = new Generation(assistantMessage2, mockChatGenerationMetadata1); + Generation generation1 = new Generation(assistantMessage1, this.mockChatGenerationMetadata1); + Generation generation2 = new Generation(assistantMessage2, this.mockChatGenerationMetadata1); assertTrue(generation1.equals(generation2)); } @@ -96,8 +113,8 @@ void testEquals_SameMetadata() { void testEquals_DifferentMetadata() { AssistantMessage assistantMessage1 = new AssistantMessage("Test Assistant Message"); AssistantMessage assistantMessage2 = new AssistantMessage("Test Assistant Message"); - Generation generation1 = new Generation(assistantMessage1, mockChatGenerationMetadata1); - Generation generation2 = new Generation(assistantMessage2, mockChatGenerationMetadata2); + Generation generation1 = new Generation(assistantMessage1, this.mockChatGenerationMetadata1); + Generation generation2 = new Generation(assistantMessage2, this.mockChatGenerationMetadata2); assertFalse(generation1.equals(generation2)); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java index 0276568dd1c..2ee37ac283d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -41,7 +43,7 @@ class ChatModelCompletionObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -53,7 +55,7 @@ void whenEmptyResponseThenReturnOriginalContext() { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -66,7 +68,7 @@ void whenEmptyCompletionThenReturnOriginalContext() { .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); expectedContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -80,7 +82,7 @@ void whenCompletionWithTextThenAugmentContext() { .build(); originalContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()).contains(KeyValue .of(HighCardinalityKeyNames.COMPLETION.asString(), "[\"say please\", \"seriously, say please\"]")); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java index f5b12e5536c..225fcbce50c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.tracing.handler.TracingObservationHandler; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; @@ -22,6 +25,7 @@ import io.opentelemetry.sdk.trace.ReadableSpan; import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -31,8 +35,6 @@ import org.springframework.ai.observation.conventions.AiObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java index 08edea5d18c..cf097fc89f5 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.Usage; @@ -28,9 +32,10 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.conventions.*; - -import java.util.List; +import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; +import org.springframework.ai.observation.conventions.AiObservationMetricNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiTokenType; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; @@ -59,7 +64,7 @@ void shouldCreateAllMetersDuringAnObservation() { var observationContext = generateObservationContext(); var observation = Observation .createNotStarted(new DefaultChatModelObservationConvention(), () -> observationContext, - observationRegistry) + this.observationRegistry) .start(); observationContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("test"))), @@ -67,20 +72,20 @@ void shouldCreateAllMetersDuringAnObservation() { observation.stop(); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.CHAT.value()) .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider") .tag(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral") .tag(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), "mistral-42") .meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.INPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.OUTPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.TOTAL.value()) .meters()).hasSize(1); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java index e723b91263c..a7c62a462b9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java index 92d9e0d8b43..8e33c73e0c1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -40,7 +42,7 @@ class ChatModelPromptContentObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -52,7 +54,7 @@ void whenEmptyPromptThenReturnOriginalContext() { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -64,7 +66,7 @@ void whenPromptWithTextThenAugmentContext() { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()).contains( KeyValue.of(HighCardinalityKeyNames.PROMPT.asString(), "[\"supercalifragilisticexpialidocious\"]")); @@ -78,7 +80,7 @@ void whenPromptWithMessagesThenAugmentContext() { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(HighCardinalityKeyNames.PROMPT.asString(), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java index 375598244a6..4064d9570da 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.tracing.handler.TracingObservationHandler; @@ -22,6 +23,7 @@ import io.opentelemetry.sdk.trace.ReadableSpan; import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index 42164fa9463..929637d788f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -27,8 +31,6 @@ import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java index 14627862d2b..0207439580a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.time.LocalDate; @@ -44,41 +45,95 @@ class BeanOutputConverterTest { private ObjectMapper objectMapperMock; @Test - public void shouldHavePreConfiguredDefaultObjectMapper() { + void shouldHavePreConfiguredDefaultObjectMapper() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); var objectMapper = converter.getObjectMapper(); assertThat(objectMapper.isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)).isFalse(); } + static class TestClass { + + private String someString; + + @SuppressWarnings("unused") + TestClass() { + } + + TestClass(String someString) { + this.someString = someString; + } + + String getSomeString() { + return this.someString; + } + + } + + static class TestClassWithDateProperty { + + private LocalDate someString; + + @SuppressWarnings("unused") + TestClassWithDateProperty() { + } + + TestClassWithDateProperty(LocalDate someString) { + this.someString = someString; + } + + LocalDate getSomeString() { + return this.someString; + } + + } + + static class TestClassWithJsonAnnotations { + + @JsonProperty("string_property") + @JsonPropertyDescription("string_property_description") + private String someString; + + TestClassWithJsonAnnotations() { + } + + String getSomeString() { + return this.someString; + } + + } + @Nested class ConverterTest { @Test - public void convertClassType() { + void convertClassType() { var converter = new BeanOutputConverter<>(TestClass.class); var testClass = converter.convert("{ \"someString\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertClassWithDateType() { + void convertClassWithDateType() { var converter = new BeanOutputConverter<>(TestClassWithDateProperty.class); var testClass = converter.convert("{ \"someString\": \"2020-01-01\" }"); assertThat(testClass.getSomeString()).isEqualTo(LocalDate.of(2020, 1, 1)); } @Test - public void convertTypeReference() { + void convertTypeReference() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); var testClass = converter.convert("{ \"someString\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertTypeReferenceArray() { + void convertTypeReferenceArray() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference>() { + }); List testClass = converter.convert("[{ \"someString\": \"some value\" }]"); assertThat(testClass).hasSize(1); @@ -86,24 +141,26 @@ public void convertTypeReferenceArray() { } @Test - public void convertClassTypeWithJsonAnnotations() { + void convertClassTypeWithJsonAnnotations() { var converter = new BeanOutputConverter<>(TestClassWithJsonAnnotations.class); var testClass = converter.convert("{ \"string_property\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertTypeReferenceWithJsonAnnotations() { + void convertTypeReferenceWithJsonAnnotations() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); var testClass = converter.convert("{ \"string_property\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertTypeReferenceArrayWithJsonAnnotations() { + void convertTypeReferenceArrayWithJsonAnnotations() { var converter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { + }); List testClass = converter .convert("[{ \"string_property\": \"some value\" }]"); @@ -113,11 +170,12 @@ public void convertTypeReferenceArrayWithJsonAnnotations() { } + // @checkstyle:off RegexpSinglelineJavaCheck @Nested class FormatTest { @Test - public void formatClassType() { + void formatClassType() { var converter = new BeanOutputConverter<>(TestClass.class); assertThat(converter.getFormat()).isEqualTo( """ @@ -140,8 +198,9 @@ public void formatClassType() { } @Test - public void formatTypeReference() { + void formatTypeReference() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); assertThat(converter.getFormat()).isEqualTo( """ @@ -164,8 +223,9 @@ public void formatTypeReference() { } @Test - public void formatTypeReferenceArray() { + void formatTypeReferenceArray() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference>() { + }); assertThat(converter.getFormat()).isEqualTo( """ @@ -191,7 +251,7 @@ public void formatTypeReferenceArray() { } @Test - public void formatClassTypeWithAnnotations() { + void formatClassTypeWithAnnotations() { var converter = new BeanOutputConverter<>(TestClassWithJsonAnnotations.class); assertThat(converter.getFormat()).contains(""" ```{ @@ -209,8 +269,9 @@ public void formatClassTypeWithAnnotations() { } @Test - public void formatTypeReferenceWithAnnotations() { + void formatTypeReferenceWithAnnotations() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); assertThat(converter.getFormat()).contains(""" ```{ @@ -226,6 +287,7 @@ public void formatTypeReferenceWithAnnotations() { }``` """); } + // @checkstyle:on RegexpSinglelineJavaCheck @Test void normalizesLineEndingsClassType() { @@ -240,6 +302,7 @@ void normalizesLineEndingsClassType() { @Test void normalizesLineEndingsTypeReference() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); String formatOutput = converter.getFormat(); @@ -250,55 +313,4 @@ void normalizesLineEndingsTypeReference() { } - public static class TestClass { - - private String someString; - - @SuppressWarnings("unused") - public TestClass() { - } - - public TestClass(String someString) { - this.someString = someString; - } - - public String getSomeString() { - return someString; - } - - } - - public static class TestClassWithDateProperty { - - private LocalDate someString; - - @SuppressWarnings("unused") - public TestClassWithDateProperty() { - } - - public TestClassWithDateProperty(LocalDate someString) { - this.someString = someString; - } - - public LocalDate getSomeString() { - return someString; - } - - } - - public static class TestClassWithJsonAnnotations { - - @JsonProperty("string_property") - @JsonPropertyDescription("string_property_description") - private String someString; - - public TestClassWithJsonAnnotations() { - } - - public String getSomeString() { - return someString; - } - - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java index 4c0795a1c0a..f63f1e2fe6d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.util.List; @@ -33,4 +34,4 @@ void csv() { assertThat(list).containsExactlyElementsOf(List.of("foo", "bar", "baz")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java index b20d7595b40..5437d965cff 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.Map; @@ -31,12 +32,12 @@ public class ContentFormatterTests { @Test public void noExplicitlySetFormatter() { - assertThat(document.getContent()).isEqualTo(""" + assertThat(this.document.getContent()).isEqualTo(""" The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getFormattedContent()).isEqualTo(document.getFormattedContent(MetadataMode.ALL)); - assertThat(document.getFormattedContent()) - .isEqualTo(document.getFormattedContent(Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.ALL)); + assertThat(this.document.getFormattedContent()).isEqualTo(this.document.getFormattedContent(MetadataMode.ALL)); + assertThat(this.document.getFormattedContent()) + .isEqualTo(this.document.getFormattedContent(Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.ALL)); } @@ -45,7 +46,7 @@ public void defaultConfigTextFormatter() { DefaultContentFormatter defaultConfigFormatter = DefaultContentFormatter.defaultConfig(); - assertThat(document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)).isEqualTo(""" + assertThat(this.document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)).isEqualTo(""" llmKey2: value4 embedKey1: value1 embedKey2: value2 @@ -53,11 +54,11 @@ public void defaultConfigTextFormatter() { The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) - .isEqualTo(document.getFormattedContent()); + assertThat(this.document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) + .isEqualTo(this.document.getFormattedContent()); - assertThat(document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) - .isEqualTo(defaultConfigFormatter.format(document, MetadataMode.ALL)); + assertThat(this.document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) + .isEqualTo(defaultConfigFormatter.format(this.document, MetadataMode.ALL)); } @Test @@ -70,23 +71,24 @@ public void customTextFormatter() { .withMetadataTemplate("Key/Value {key}={value}") .build(); - assertThat(document.getFormattedContent(textFormatter, MetadataMode.EMBED)).isEqualTo(""" + assertThat(this.document.getFormattedContent(textFormatter, MetadataMode.EMBED)).isEqualTo(""" Metadata: Key/Value llmKey2=value4 Key/Value embedKey1=value1 Text:The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getContent()).isEqualTo(""" + assertThat(this.document.getContent()).isEqualTo(""" The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getFormattedContent(textFormatter, MetadataMode.EMBED)) - .isEqualTo(textFormatter.format(document, MetadataMode.EMBED)); + assertThat(this.document.getFormattedContent(textFormatter, MetadataMode.EMBED)) + .isEqualTo(textFormatter.format(this.document, MetadataMode.EMBED)); - var documentWithCustomFormatter = new Document(document.getId(), document.getContent(), document.getMetadata()); + var documentWithCustomFormatter = new Document(this.document.getId(), this.document.getContent(), + this.document.getMetadata()); documentWithCustomFormatter.setContentFormatter(textFormatter); - assertThat(document.getFormattedContent(textFormatter, MetadataMode.ALL)) + assertThat(this.document.getFormattedContent(textFormatter, MetadataMode.ALL)) .isEqualTo(documentWithCustomFormatter.getFormattedContent()); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java index ee745d9a1ec..ebaeef38905 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.document; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.springframework.ai.model.Media; -import org.springframework.ai.document.id.IdGenerator; -import org.springframework.util.MimeTypeUtils; +package org.springframework.ai.document; import java.net.MalformedURLException; import java.net.URL; @@ -27,6 +22,13 @@ import java.util.List; import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.document.id.IdGenerator; +import org.springframework.ai.model.Media; +import org.springframework.util.MimeTypeUtils; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -34,23 +36,39 @@ public class DocumentBuilderTests { private Document.Builder builder; + private static List getMediaList() { + try { + URL mediaUrl1 = new URL("http://type1"); + URL mediaUrl2 = new URL("http://type2"); + Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1); + Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2); + List mediaList = List.of(media1, media2); + return mediaList; + } + catch (MalformedURLException e) { + throw new RuntimeException(e); + } + + } + @BeforeEach void setUp() { - builder = Document.builder(); + this.builder = Document.builder(); } @Test void testWithIdGenerator() { IdGenerator mockGenerator = new IdGenerator() { + @Override public String generateId(Object... contents) { return "mockedId"; } }; - Document.Builder result = builder.withIdGenerator(mockGenerator); + Document.Builder result = this.builder.withIdGenerator(mockGenerator); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); Document document = result.withContent("Test content").withMetadata("key", "value").build(); @@ -59,53 +77,54 @@ public String generateId(Object... contents) { @Test void testWithIdGeneratorNull() { - assertThatThrownBy(() -> builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("idGenerator must not be null"); } @Test void testWithId() { - Document.Builder result = builder.withId("testId"); + Document.Builder result = this.builder.withId("testId"); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); assertThat(result.build().getId()).isEqualTo("testId"); } @Test void testWithIdNullOrEmpty() { - assertThatThrownBy(() -> builder.withId(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withId(null)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("id must not be null or empty"); - assertThatThrownBy(() -> builder.withId("")).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withId("")).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("id must not be null or empty"); } @Test void testWithContent() { - Document.Builder result = builder.withContent("Test content"); + Document.Builder result = this.builder.withContent("Test content"); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); assertThat(result.build().getContent()).isEqualTo("Test content"); } @Test void testWithContentNull() { - assertThatThrownBy(() -> builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("content must not be null"); } @Test void testWithMediaList() { List mediaList = getMediaList(); - Document.Builder result = builder.withMedia(mediaList); + Document.Builder result = this.builder.withMedia(mediaList); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); assertThat(result.build().getMedia()).isEqualTo(mediaList); } @Test void testWithMediaListNull() { - assertThatThrownBy(() -> builder.withMedia((List) null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withMedia((List) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("media must not be null"); } @@ -114,15 +133,15 @@ void testWithMediaSingle() throws MalformedURLException { URL mediaUrl = new URL("http://test"); Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl); - Document.Builder result = builder.withMedia(media); + Document.Builder result = this.builder.withMedia(media); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); assertThat(result.build().getMedia()).contains(media); } @Test void testWithMediaSingleNull() { - assertThatThrownBy(() -> builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("media must not be null"); } @@ -131,39 +150,39 @@ void testWithMetadataMap() { Map metadata = new HashMap<>(); metadata.put("key1", "value1"); metadata.put("key2", 2); - Document.Builder result = builder.withMetadata(metadata); + Document.Builder result = this.builder.withMetadata(metadata); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); assertThat(result.build().getMetadata()).isEqualTo(metadata); } @Test void testWithMetadataMapNull() { - assertThatThrownBy(() -> builder.withMetadata((Map) null)) + assertThatThrownBy(() -> this.builder.withMetadata((Map) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("metadata must not be null"); } @Test void testWithMetadataKeyValue() { - Document.Builder result = builder.withMetadata("key", "value"); + Document.Builder result = this.builder.withMetadata("key", "value"); - assertThat(result).isSameAs(builder); + assertThat(result).isSameAs(this.builder); assertThat(result.build().getMetadata()).containsEntry("key", "value"); } @Test void testWithMetadataKeyValueNull() { - assertThatThrownBy(() -> builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("key must not be null"); - assertThatThrownBy(() -> builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> this.builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("value must not be null"); } @Test void testBuildWithoutId() { - Document document = builder.withContent("Test content").build(); + Document document = this.builder.withContent("Test content").build(); assertThat(document.getId()).isNotNull().isNotEmpty(); assertThat(document.getContent()).isEqualTo("Test content"); @@ -176,7 +195,7 @@ void testBuildWithAllProperties() throws MalformedURLException { Map metadata = new HashMap<>(); metadata.put("key", "value"); - Document document = builder.withId("customId") + Document document = this.builder.withId("customId") .withContent("Test content") .withMedia(mediaList) .withMetadata(metadata) @@ -188,19 +207,4 @@ void testBuildWithAllProperties() throws MalformedURLException { assertThat(document.getMetadata()).isEqualTo(metadata); } - private static List getMediaList() { - try { - URL mediaUrl1 = new URL("http://type1"); - URL mediaUrl2 = new URL("http://type2"); - Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1); - Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2); - List mediaList = List.of(media1, media2); - return mediaList; - } - catch (MalformedURLException e) { - throw new RuntimeException(e); - } - - } - } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java b/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java index 2e74d671bfe..5072e51cf03 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.util.Map; @@ -64,4 +65,4 @@ void hashGeneratorGenerateDifferentIdsForDifferentContent() { Assertions.assertDoesNotThrow(() -> UUID.fromString(actualHashes2)); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java b/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java index 4fc94f62ee9..6d610fa2ade 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.nio.charset.Charset; @@ -28,8 +29,8 @@ public class JdkSha256HexIdGeneratorTest { @Test void messageDigestReturnsDistinctInstances() { - final MessageDigest md1 = testee.getMessageDigest(); - final MessageDigest md2 = testee.getMessageDigest(); + final MessageDigest md1 = this.testee.getMessageDigest(); + final MessageDigest md2 = this.testee.getMessageDigest(); Assertions.assertThat(md1 != md2).isTrue(); @@ -45,10 +46,10 @@ void messageDigestReturnsInstancesWithIndependentAndReproducibleDigests() { final String updateString2 = "md2_update"; final Charset charset = StandardCharsets.UTF_8; - final byte[] md1BytesFirstTry = testee.getMessageDigest().digest(updateString1.getBytes(charset)); - final byte[] md2BytesFirstTry = testee.getMessageDigest().digest(updateString2.getBytes(charset)); - final byte[] md1BytesSecondTry = testee.getMessageDigest().digest(updateString1.getBytes(charset)); - final byte[] md2BytesSecondTry = testee.getMessageDigest().digest(updateString2.getBytes(charset)); + final byte[] md1BytesFirstTry = this.testee.getMessageDigest().digest(updateString1.getBytes(charset)); + final byte[] md2BytesFirstTry = this.testee.getMessageDigest().digest(updateString2.getBytes(charset)); + final byte[] md1BytesSecondTry = this.testee.getMessageDigest().digest(updateString1.getBytes(charset)); + final byte[] md2BytesSecondTry = this.testee.getMessageDigest().digest(updateString2.getBytes(charset)); Assertions.assertThat(md1BytesFirstTry).isNotEqualTo(md2BytesFirstTry); @@ -56,4 +57,4 @@ void messageDigestReturnsInstancesWithIndependentAndReproducibleDigests() { Assertions.assertThat(md2BytesFirstTry).isEqualTo(md2BytesSecondTry); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java index 88ff94632d5..c64a3cedfe0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; @@ -29,9 +30,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -79,16 +80,17 @@ public EmbeddingResponse call(EmbeddingRequest request) { @ParameterizedTest @CsvFileSource(resources = "/embedding/embedding-model-dimensions.properties", numLinesToSkip = 1, delimiter = '=') public void testKnownEmbeddingModelDimensions(String model, String dimension) { - assertThat(AbstractEmbeddingModel.dimensions(embeddingModel, model, "Hello world!")) + assertThat(AbstractEmbeddingModel.dimensions(this.embeddingModel, model, "Hello world!")) .isEqualTo(Integer.valueOf(dimension)); - verify(embeddingModel, never()).embed(any(String.class)); - verify(embeddingModel, never()).embed(any(Document.class)); + verify(this.embeddingModel, never()).embed(any(String.class)); + verify(this.embeddingModel, never()).embed(any(Document.class)); } @Test public void testUnknownModelDimension() { - when(embeddingModel.embed(eq("Hello world!"))).thenReturn(new float[] { 0.1f, 0.1f, 0.1f }); - assertThat(AbstractEmbeddingModel.dimensions(embeddingModel, "unknown_model", "Hello world!")).isEqualTo(3); + given(this.embeddingModel.embed(eq("Hello world!"))).willReturn(new float[] { 0.1f, 0.1f, 0.1f }); + assertThat(AbstractEmbeddingModel.dimensions(this.embeddingModel, "unknown_model", "Hello world!")) + .isEqualTo(3); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java index f809ccf27d2..e3afc398426 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.embedding; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +package org.springframework.ai.embedding; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -28,6 +26,9 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + /** * Basic unit test for {@link TokenCountBatchingStrategy}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java index f973d95eb49..977c30a443a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; +import java.util.List; +import java.util.Map; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java index 560f37a55b8..a97afb9d1d3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,23 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; +import java.util.List; +import java.util.Map; + import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import org.springframework.ai.observation.conventions.*; - -import java.util.List; -import java.util.Map; +import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; +import org.springframework.ai.observation.conventions.AiObservationMetricNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiTokenType; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -58,7 +63,7 @@ void shouldCreateAllMetersDuringAnObservation() { var observationContext = generateObservationContext(); var observation = Observation .createNotStarted(new DefaultEmbeddingModelObservationConvention(), () -> observationContext, - observationRegistry) + this.observationRegistry) .start(); observationContext.setResponse(new EmbeddingResponse(List.of(), @@ -66,20 +71,20 @@ void shouldCreateAllMetersDuringAnObservation() { observation.stop(); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.EMBEDDING.value()) .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider") .tag(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral") .tag(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), "mistral-42") .meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.INPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.OUTPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.TOTAL.value()) .meters()).hasSize(1); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java index 8c3bbb0cc66..0678fe26ad4 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java index 5c8951de3b3..6681402f4ad 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java index ac59c321ad3..ebc4685ee51 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java index 7fc11e39d5c..e422a9b40c3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageMessage; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,7 +41,7 @@ class ImageModelPromptContentObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -51,7 +53,7 @@ void whenEmptyPromptThenReturnOriginalContext() { .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -63,7 +65,7 @@ void whenPromptWithTextThenAugmentContext() { .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(AiObservationAttributes.PROMPT.value(), "[\"supercalifragilisticexpialidocious\"]")); @@ -77,7 +79,7 @@ void whenPromptWithMessagesThenAugmentContext() { .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(AiObservationAttributes.PROMPT.value(), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java index ebeabb4fa0c..0e6353e828a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.metadata; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; +package org.springframework.ai.metadata; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + /** * Unit Tests for {@link PromptMetadata}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java index 18d7d63922d..cac20367870 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.metadata; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.metadata.Usage; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doReturn; @@ -23,9 +28,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.metadata.Usage; - /** * Unit Tests for {@link Usage}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java index 7b03f7e2d71..a764ccea2e6 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import java.util.Map; @@ -28,104 +29,6 @@ */ public class ModelOptionsUtilsTests { - public static interface TestPortableOptions extends ModelOptions { - - String getName(); - - void setName(String name); - - Integer getAge(); - - void setAge(Integer age); - - } - - public static class TestPortableOptionsImpl implements TestPortableOptions { - - private String name; - - private Integer age; - - // Non interface fields - private String nonInterfaceField; - - @Override - public String getName() { - return name; - } - - @Override - public void setName(String name) { - this.name = name; - } - - @Override - public Integer getAge() { - return age; - } - - @Override - public void setAge(Integer age) { - this.age = age; - } - - public String getNonInterfaceField() { - return nonInterfaceField; - } - - public void setNonInterfaceField(String nonInterfaceField) { - this.nonInterfaceField = nonInterfaceField; - } - - } - - public static class TestSpecificOptions implements TestPortableOptions { - - @JsonProperty("specificField") - private String specificField; - - @JsonProperty("name") - private String name; - - @JsonProperty("age") - private Integer age; - - @Override - public String getName() { - return name; - } - - @Override - public void setName(String name) { - this.name = name; - } - - @Override - public Integer getAge() { - return age; - } - - @Override - public void setAge(Integer age) { - this.age = age; - } - - public String getSpecificField() { - return specificField; - } - - public void setSpecificField(String modelSpecificField) { - this.specificField = modelSpecificField; - } - - @Override - public String toString() { - return "TestModelSpecificOptions{" + "specificField='" + specificField + '\'' + ", name='" + name + '\'' - + ", age=" + age + '}'; - } - - } - @Test public void merge() { TestPortableOptionsImpl portableOptions = new TestPortableOptionsImpl(); @@ -145,7 +48,7 @@ public void merge() { assertThat(specificOptions2.getAge()).isEqualTo(30); assertThat(specificOptions2.getName()).isEqualTo("John"); // !!! Overridden by the - // portableOptions + // portableOptions assertThat(specificOptions2.getSpecificField()).isEqualTo("SpecificField"); } @@ -221,9 +124,108 @@ public void copyToTarget() { @Test public void getJsonPropertyValues() { record TestRecord(@JsonProperty("field1") String fieldA, @JsonProperty("field2") String fieldB) { + } assertThat(ModelOptionsUtils.getJsonPropertyValues(TestRecord.class)).hasSize(2); assertThat(ModelOptionsUtils.getJsonPropertyValues(TestRecord.class)).containsExactly("field1", "field2"); } -} \ No newline at end of file + public interface TestPortableOptions extends ModelOptions { + + String getName(); + + void setName(String name); + + Integer getAge(); + + void setAge(Integer age); + + } + + public static class TestPortableOptionsImpl implements TestPortableOptions { + + private String name; + + private Integer age; + + // Non interface fields + private String nonInterfaceField; + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public Integer getAge() { + return this.age; + } + + @Override + public void setAge(Integer age) { + this.age = age; + } + + public String getNonInterfaceField() { + return this.nonInterfaceField; + } + + public void setNonInterfaceField(String nonInterfaceField) { + this.nonInterfaceField = nonInterfaceField; + } + + } + + public static class TestSpecificOptions implements TestPortableOptions { + + @JsonProperty("specificField") + private String specificField; + + @JsonProperty("name") + private String name; + + @JsonProperty("age") + private Integer age; + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public Integer getAge() { + return this.age; + } + + @Override + public void setAge(Integer age) { + this.age = age; + } + + public String getSpecificField() { + return this.specificField; + } + + public void setSpecificField(String modelSpecificField) { + this.specificField = modelSpecificField; + } + + @Override + public String toString() { + return "TestModelSpecificOptions{" + "specificField='" + this.specificField + '\'' + ", name='" + this.name + + '\'' + ", age=" + this.age + '}'; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java index b5ac63a5200..69b60b35519 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java index f4647be23c9..fb532d9ce3c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.lang.reflect.Type; @@ -39,8 +40,8 @@ class TypeResolverHelperIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction" }) void beanInputTypeResolutionTest(String beanName) { - assertThat(applicationContext).isNotNull(); - Type beanType = FunctionContextUtils.findType(applicationContext.getBeanFactory(), beanName); + assertThat(this.applicationContext).isNotNull(); + Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName); assertThat(beanType).isNotNull(); Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0); assertThat(functionInputType).isNotNull(); @@ -49,9 +50,11 @@ void beanInputTypeResolutionTest(String beanName) { } public record WeatherRequest(String city) { + } public record WeatherResponse(float temperatureInCelsius) { + } public static class Outer { @@ -70,17 +73,17 @@ public WeatherResponse apply(WeatherRequest weatherRequest) { @SpringBootConfiguration public static class TypeResolverHelperConfiguration { - @Bean() + @Bean Outer.InnerWeatherFunction weatherClassDefinition() { return new Outer.InnerWeatherFunction(); } - @Bean() + @Bean Function weatherFunctionDefinition() { return new Outer.InnerWeatherFunction(); } - @Bean() + @Bean StandaloneWeatherFunction standaloneWeatherFunction() { return new StandaloneWeatherFunction(); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java index 76622a22281..8051fe9912f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.function.Function; @@ -27,7 +28,7 @@ import org.springframework.ai.model.function.TypeResolverHelperTests.MockWeatherService.Request; import org.springframework.ai.model.function.TypeResolverHelperTests.MockWeatherService.Response; -import static org.assertj.core.api.Assertions.assertThat;; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -64,6 +65,11 @@ public String apply(Response response) { public static class MockWeatherService implements Function { + @Override + public Response apply(Request request) { + return new Response(10, "C"); + } + /** * Weather Function request. */ @@ -75,14 +81,11 @@ public record Request(@JsonProperty(required = true, @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") String unit) { + } public record Response(double temp, String unit) { - } - @Override - public Response apply(Request request) { - return new Response(10, "C"); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java index 7bb47f11ccd..ce69e72355c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java @@ -1,6 +1,23 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java index 53949df144f..3ef061335c0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.observation; import io.micrometer.common.KeyValue; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; import org.springframework.ai.observation.conventions.AiObservationMetricNames; @@ -86,7 +88,7 @@ static class TestUsage implements Usage { private final Long totalTokens; - public TestUsage(Long promptTokens, Long generationTokens, Long totalTokens) { + TestUsage(Long promptTokens, Long generationTokens, Long totalTokens) { this.promptTokens = promptTokens; this.generationTokens = generationTokens; this.totalTokens = totalTokens; @@ -94,17 +96,17 @@ public TestUsage(Long promptTokens, Long generationTokens, Long totalTokens) { @Override public Long getPromptTokens() { - return promptTokens; + return this.promptTokens; } @Override public Long getGenerationTokens() { - return generationTokens; + return this.generationTokens; } @Override public Long getTotalTokens() { - return totalTokens; + return this.totalTokens; } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java index 59e822ddfd4..d538245b2d7 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation; import org.junit.jupiter.api.Test; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java index aa5848370ba..78ed778b2a8 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java @@ -1,17 +1,32 @@ -package org.springframework.ai.observation.tracing; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.observation.tracing; import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.Test; - import io.micrometer.tracing.Span; import io.micrometer.tracing.TraceContext; import io.micrometer.tracing.handler.TracingObservationHandler; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; import io.opentelemetry.api.OpenTelemetry; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link TracingHelper}. @@ -125,4 +140,4 @@ public Span remoteIpAndPort(String s, int i) { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java index 836f3cfc8c0..711889d2ce1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.prompt; public class ChatTests { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java index 835bd59e790..096d5e21cca 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.prompt; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -25,14 +35,6 @@ import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; -import java.io.ByteArrayInputStream; -import java.io.InputStream; -import java.nio.charset.Charset; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -40,6 +42,18 @@ public class PromptTemplateTest { + private static Map createTestMap() { + Map model = new HashMap<>(); + model.put("key1", "value1"); + model.put("key2", true); + return model; + } + + private static void assertEqualsWithNormalizedEOLs(String expected, String actual) { + assertEquals(expected.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator()), + actual.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator())); + } + @Test public void testCreateWithEmptyModelAndChatOptions() { String template = "This is a test prompt with no variables"; @@ -154,13 +168,6 @@ public void testRenderResource() { assertEquals(expected, result); } - private static Map createTestMap() { - Map model = new HashMap<>(); - model.put("key1", "value1"); - model.put("key2", true); - return model; - } - @Disabled("Need to improve PromptTemplate to better handle Resource toString and tracking with 'dynamicModel' for underlying StringTemplate") @Test public void testRenderResourceAsValue() throws Exception { @@ -199,9 +206,4 @@ public void testRenderFailure() { assertThrows(IllegalStateException.class, promptTemplate::render); } - private static void assertEqualsWithNormalizedEOLs(String expected, String actual) { - assertEquals(expected.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator()), - actual.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator())); - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java index 62bb62b8056..7c25b8d3065 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.prompt; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; @SuppressWarnings("unchecked") @@ -83,7 +85,7 @@ void newApiPlaygroundTests() { Prompt systemPrompt = promptTemplate.create(systemModel); promptTemplate = new PromptTemplate(humanTemplate); // creates a Prompt with - // HumanMessage + // HumanMessage Prompt humanPrompt = promptTemplate.create(humanModel); // ChatPromptTemplate chatPromptTemplate = new ChatPromptTemplate(systemPrompt, diff --git a/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java index b57bc99c362..af7aa19d7f9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest @@ -39,8 +41,8 @@ public class JsonReaderTests { @Test void loadJsonArray() { - assertThat(arrayResource).isNotNull(); - JsonReader jsonReader = new JsonReader(arrayResource, "description"); + assertThat(this.arrayResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.arrayResource, "description"); List documents = jsonReader.get(); assertThat(documents).isNotEmpty(); for (Document document : documents) { @@ -50,8 +52,8 @@ void loadJsonArray() { @Test void loadJsonObject() { - assertThat(ObjectResource).isNotNull(); - JsonReader jsonReader = new JsonReader(ObjectResource, "description"); + assertThat(this.ObjectResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.ObjectResource, "description"); List documents = jsonReader.get(); assertThat(documents).isNotEmpty(); for (Document document : documents) { @@ -61,8 +63,8 @@ void loadJsonObject() { @Test void loadJsonArrayFromPointer() { - assertThat(arrayResource).isNotNull(); - JsonReader jsonReader = new JsonReader(eventsResource, "description"); + assertThat(this.arrayResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.eventsResource, "description"); List documents = jsonReader.get("/0/sessions"); assertThat(documents).isNotEmpty(); for (Document document : documents) { @@ -73,8 +75,8 @@ void loadJsonArrayFromPointer() { @Test void loadJsonObjectFromPointer() { - assertThat(ObjectResource).isNotNull(); - JsonReader jsonReader = new JsonReader(ObjectResource, "name"); + assertThat(this.ObjectResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.ObjectResource, "name"); List documents = jsonReader.get("/store"); assertThat(documents).isNotEmpty(); assertThat(documents.size()).isEqualTo(1); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java index 3db8952a415..835d17f777a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; -import java.io.File; -import java.io.IOException; -import java.net.URI; -import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; @@ -28,7 +25,6 @@ import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.DefaultResourceLoader; -import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import static org.assertj.core.api.Assertions.assertThat; @@ -105,4 +101,4 @@ void loadTextFromByteArrayResource() { assertThat(customDocument.getContent()).isEqualTo("Another test content"); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java index a5caf706ee0..c3172380dac 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer.splitter; import java.util.ArrayList; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java index 0baefc0acb9..c30225b5cca 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java @@ -1,12 +1,29 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.transformer.splitter; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.ai.document.DefaultContentFormatter; import org.springframework.ai.document.Document; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java index 18d0b3424db..12084d00797 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -31,8 +32,8 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; /** * @author Christian Tzolov @@ -44,13 +45,14 @@ public class FilterExpressionBuilderTests { @Test public void testEQ() { // country == "BG" - assertThat(b.eq("country", "BG").build()).isEqualTo(new Expression(EQ, new Key("country"), new Value("BG"))); + assertThat(this.b.eq("country", "BG").build()) + .isEqualTo(new Expression(EQ, new Key("country"), new Value("BG"))); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - Expression exp = b.and(b.eq("genre", "drama"), b.gte("year", 2020)).build(); + Expression exp = this.b.and(this.b.eq("genre", "drama"), this.b.gte("year", 2020)).build(); assertThat(exp).isEqualTo(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); } @@ -58,7 +60,7 @@ public void tesEqAndGte() { @Test public void testIn() { // genre in ["comedy", "documentary", "drama"] - var exp = b.in("genre", "comedy", "documentary", "drama").build(); + var exp = this.b.in("genre", "comedy", "documentary", "drama").build(); assertThat(exp) .isEqualTo(new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); } @@ -66,7 +68,9 @@ public void testIn() { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - var exp = b.and(b.or(b.gte("year", 2020), b.eq("country", "BG")), b.ne("city", "Sofia")).build(); + var exp = this.b + .and(this.b.or(this.b.gte("year", 2020), this.b.eq("country", "BG")), this.b.ne("city", "Sofia")) + .build(); assertThat(exp).isEqualTo(new Expression(AND, new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), @@ -77,7 +81,9 @@ public void testNe() { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - var exp = b.and(b.group(b.or(b.gte("year", 2020), b.eq("country", "BG"))), b.nin("city", "Sofia", "Plovdiv")) + var exp = this.b + .and(this.b.group(this.b.or(this.b.gte("year", 2020), this.b.eq("country", "BG"))), + this.b.nin("city", "Sofia", "Plovdiv")) .build(); assertThat(exp).isEqualTo(new Expression(AND, @@ -89,7 +95,10 @@ public void testGroup() { @Test public void tesIn2() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - var exp = b.and(b.and(b.eq("isOpen", true), b.gte("year", 2020)), b.in("country", "BG", "NL", "US")).build(); + var exp = this.b + .and(this.b.and(this.b.eq("isOpen", true), this.b.gte("year", 2020)), + this.b.in("country", "BG", "NL", "US")) + .build(); assertThat(exp).isEqualTo(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), @@ -100,7 +109,8 @@ public void tesIn2() { @Test public void tesNot() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - var exp = b.not(b.and(b.and(b.eq("isOpen", true), b.gte("year", 2020)), b.in("country", "BG", "NL", "US"))) + var exp = this.b.not(this.b.and(this.b.and(this.b.eq("isOpen", true), this.b.gte("year", 2020)), + this.b.in("country", "BG", "NL", "US"))) .build(); assertThat(exp).isEqualTo(new Expression(NOT, diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java index 8253fb234c4..2c16705e22c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -45,55 +46,55 @@ public class FilterExpressionTextParserTests { @Test public void testEQ() { // country == "BG" - Expression exp = parser.parse("country == 'BG'"); + Expression exp = this.parser.parse("country == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("country"), new Value("BG"))); - assertThat(parser.getCache().get("WHERE " + "country == 'BG'")).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + "country == 'BG'")).isEqualTo(exp); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - Expression exp = parser.parse("genre == 'drama' && year >= 2020"); + Expression exp = this.parser.parse("genre == 'drama' && year >= 2020"); assertThat(exp).isEqualTo(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); - assertThat(parser.getCache().get("WHERE " + "genre == 'drama' && year >= 2020")).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + "genre == 'drama' && year >= 2020")).isEqualTo(exp); } @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - Expression exp = parser.parse("genre in ['comedy', 'documentary', 'drama']"); + Expression exp = this.parser.parse("genre in ['comedy', 'documentary', 'drama']"); assertThat(exp) .isEqualTo(new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); - assertThat(parser.getCache().get("WHERE " + "genre in ['comedy', 'documentary', 'drama']")).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + "genre in ['comedy', 'documentary', 'drama']")).isEqualTo(exp); } @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - Expression exp = parser.parse("year >= 2020 OR country == \"BG\" AND city != \"Sofia\""); + Expression exp = this.parser.parse("year >= 2020 OR country == \"BG\" AND city != \"Sofia\""); assertThat(exp).isEqualTo(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); - assertThat(parser.getCache().get("WHERE " + "year >= 2020 OR country == \"BG\" AND city != \"Sofia\"")) + assertThat(this.parser.getCache().get("WHERE " + "year >= 2020 OR country == \"BG\" AND city != \"Sofia\"")) .isEqualTo(exp); } @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - Expression exp = parser.parse("(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]"); + Expression exp = this.parser.parse("(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]"); assertThat(exp).isEqualTo(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]")) .isEqualTo(exp); } @@ -101,20 +102,21 @@ public void testGroup() { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - Expression exp = parser.parse("isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]"); + Expression exp = this.parser.parse("isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]"); assertThat(exp).isEqualTo(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]")).isEqualTo(exp); } @Test public void tesNot() { // NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"]) - Expression exp = parser.parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"); + Expression exp = this.parser + .parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"); assertThat(exp).isEqualTo(new Expression(NOT, new Group(new Expression(AND, @@ -123,7 +125,7 @@ public void tesNot() { new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))), null)); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])")) .isEqualTo(exp); } @@ -131,7 +133,7 @@ public void tesNot() { @Test public void tesNotNin() { // NOT(country NOT IN ["BG", "NL", "US"]) - Expression exp = parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])"); + Expression exp = this.parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])"); assertThat(exp).isEqualTo(new Expression(NOT, new Group(new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US")))), null)); @@ -140,7 +142,7 @@ public void tesNotNin() { @Test public void tesNotNin2() { // NOT country NOT IN ["BG", "NL", "US"] - Expression exp = parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]"); + Expression exp = this.parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]"); assertThat(exp).isEqualTo(new Expression(NOT, new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US"))), null)); @@ -149,7 +151,7 @@ public void tesNotNin2() { @Test public void tesNestedNot() { // NOT(isOpen == true AND year >= 2020 AND NOT(country IN ["BG", "NL", "US"])) - Expression exp = parser + Expression exp = this.parser .parse("not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))"); assertThat(exp).isEqualTo(new Expression(NOT, @@ -161,7 +163,7 @@ public void tesNestedNot() { null))), null)); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))")) .isEqualTo(exp); } @@ -170,23 +172,23 @@ public void tesNestedNot() { public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 String expText = "temperature >= -15.6 && temperature <= +20.13"; - Expression exp = parser.parse(expText); + Expression exp = this.parser.parse(expText); assertThat(exp).isEqualTo(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); - assertThat(parser.getCache().get("WHERE " + expText)).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + expText)).isEqualTo(exp); } @Test public void testIdentifiers() { - Expression exp = parser.parse("'country.1' == 'BG'"); + Expression exp = this.parser.parse("'country.1' == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("'country.1'"), new Value("BG"))); - exp = parser.parse("'country_1_2_3' == 'BG'"); + exp = this.parser.parse("'country_1_2_3' == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("'country_1_2_3'"), new Value("BG"))); - exp = parser.parse("\"country 1 2 3\" == 'BG'"); + exp = this.parser.parse("\"country 1 2 3\" == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java index 472e9b1d881..df793ecf06f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -165,6 +166,6 @@ else if (expression.type() == ExpressionType.NIN) { } } - }; + } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java index 5535766b793..818ccfcae66 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import org.junit.jupiter.api.Test; @@ -67,7 +68,7 @@ public void withQuery() { assertThat(emptyRequest.getQuery()).isEqualTo("New Query"); } - @Test() + @Test public void withSimilarityThreshold() { var request = SearchRequest.query("Test").withSimilarityThreshold(0.678); assertThat(request.getSimilarityThreshold()).isEqualTo(0.678); @@ -87,7 +88,7 @@ public void withSimilarityThreshold() { } - @Test() + @Test public void withTopK() { var request = SearchRequest.query("Test").withTopK(66); assertThat(request.getTopK()).isEqualTo(66); @@ -101,7 +102,7 @@ public void withTopK() { } - @Test() + @Test public void withFilterExpression() { var request = SearchRequest.query("Test").withFilterExpression("country == 'BG' && year >= 2022"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java index e86b927f017..9fc858aa120 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import java.util.List; @@ -45,14 +46,14 @@ public class PineconeFilterExpressionConverterTests { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"country\": {\"$eq\": \"BG\"}}"); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr) @@ -62,7 +63,7 @@ public void tesEqAndGte() { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("{\"genre\": {\"$in\": [\"comedy\",\"documentary\",\"drama\"]}}"); } @@ -70,7 +71,7 @@ public void tesIn() { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -81,7 +82,7 @@ public void testNe() { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -92,7 +93,7 @@ public void testGroup() { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -104,7 +105,7 @@ public void tesBoolean() { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -114,11 +115,11 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"country 1 2 3\": {\"$eq\": \"BG\"}}"); - vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"country 1 2 3\": {\"$eq\": \"BG\"}}"); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java index 5882f92858b..981ac04b13b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.observation; import java.util.List; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link DefaultVectorStoreObservationConvention}. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java index 06d54315135..6f6abd87355 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import org.junit.jupiter.api.Test; - /** * Unit tests for {@link VectorStoreObservationContext}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java index 652c288623a..ba7a37e0523 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.observation; import java.util.List; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link VectorStoreQueryResponseObservationFilter}. @@ -39,7 +40,7 @@ class VectorStoreQueryResponseObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -49,7 +50,7 @@ void whenEmptyQueryResponseThenReturnOriginalContext() { var expectedContext = VectorStoreObservationContext.builder("db", VectorStoreObservationContext.Operation.ADD) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -63,7 +64,7 @@ void whenNonEmptyQueryResponseThenAugmentContext() { expectedContext.setQueryResponse(queryResponseDocs); - var augmentedContext = observationFilter.map(expectedContext); + var augmentedContext = this.observationFilter.map(expectedContext); assertThat(augmentedContext.getHighCardinalityKeyValues()).contains(KeyValue .of(HighCardinalityKeyNames.DB_VECTOR_QUERY_RESPONSE_DOCUMENTS.asString(), "[\"doc1\", \"doc2\"]")); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java index 657f5555fed..499c3cc0216 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; +import java.util.List; + import io.micrometer.tracing.handler.TracingObservationHandler; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; @@ -22,13 +25,12 @@ import io.opentelemetry.sdk.trace.ReadableSpan; import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; import org.springframework.ai.observation.conventions.VectorStoreObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/spring-ai-core/src/test/resources/application-logging-test.properties b/spring-ai-core/src/test/resources/application-logging-test.properties index 8ba46b8d771..8c5bc06a73b 100644 --- a/spring-ai-core/src/test/resources/application-logging-test.properties +++ b/spring-ai-core/src/test/resources/application-logging-test.properties @@ -1,2 +1,17 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# logging.level.org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor=DEBUG logging.level.ch.qos.logback=ERROR diff --git a/spring-ai-core/src/test/resources/bikes.json b/spring-ai-core/src/test/resources/bikes.json index 4865975154c..62a75ebed94 100644 --- a/spring-ai-core/src/test/resources/bikes.json +++ b/spring-ai-core/src/test/resources/bikes.json @@ -1,265 +1,265 @@ [ - { - "name": "E-Adrenaline 8.0 EX1", - "shortDescription": "a versatile and comfortable e-MTB designed for adrenaline enthusiasts who want to explore all types of terrain. It features a powerful motor and advanced suspension to provide a smooth and responsive ride, with a variety of customizable settings to fit any rider's needs.", - "description": "## Overview\r\nIt's right for you if...\r\nYou want to push your limits on challenging trails and terrain, with the added benefit of an electric assist to help you conquer steep climbs and rough terrain. You also want a bike with a comfortable and customizable fit, loaded with high-quality components and technology.\r\n\r\nThe tech you get\r\nA lightweight, full ADV Mountain Carbon frame with a customizable geometry, including an adjustable head tube and chainstay length. A powerful and efficient motor with a 375Wh battery that can assist up to 28 mph when it's on, and provides a smooth and seamless transition when it's off. A SRAM EX1 8-speed drivetrain, a RockShox Lyrik Ultimate fork, and a RockShox Super Deluxe Ultimate rear shock.\r\n\r\nThe final word\r\nOur E-Adrenaline 8.0 EX1 is the perfect bike for adrenaline enthusiasts who want to explore all types of terrain. It's versatile, comfortable, and loaded with advanced technology to provide a smooth and responsive ride, no matter where your adventures take you.\r\n\r\n\r\n## Features\r\nVersatile and customizable\r\nThe E-Adrenaline 8.0 EX1 features a customizable geometry, including an adjustable head tube and chainstay length, so you can fine-tune your ride to fit your needs and preferences. It also features a variety of customizable settings, including suspension tuning, motor assistance levels, and more.\r\n\r\nPowerful and efficient\r\nThe bike is equipped with a powerful and efficient motor that provides a smooth and seamless transition between human power and electric assist. It can assist up to 28 mph when it's on, and provides zero drag when it's off.\r\n\r\nAdvanced suspension\r\nThe E-Adrenaline 8.0 EX1 features a RockShox Lyrik Ultimate fork and a RockShox Super Deluxe Ultimate rear shock, providing advanced suspension technology to absorb shocks and bumps on any terrain. The suspension is also customizable to fit your riding style and preferences.\r\n\r\n\r\n## Specs\r\nFrameset\r\nFrame ADV Mountain Carbon main frame & stays, adjustable head tube and chainstay length, tapered head tube, Knock Block, Control Freak internal routing, Boost148, 150mm travel\r\nFork RockShox Lyrik Ultimate, DebonAir spring, Charger 2.1 RC2 damper, remote lockout, tapered steerer, 42mm offset, Boost110, 15mm Maxle Stealth, 160mm travel\r\nShock RockShox Super Deluxe Ultimate, DebonAir spring, Thru Shaft 3-position damper, 230x57.5mm\r\n\r\nWheels\r\nWheel front Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 6-bolt, Boost110, 15mm thru axle\r\nWheel rear Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 54T Rapid Drive, 6-bolt, Shimano MicroSpline freehub, Boost148, 12mm thru axle\r\nSkewer rear Bontrager Switch thru axle, removable lever\r\nTire Bontrager XR5 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.50''\r\nTire part Bontrager TLR sealant, 6oz\r\n\r\nDrivetrain\r\nShifter SRAM EX1, 8 speed\r\nRear derailleur SRAM EX1, 8 speed\r\nCrank Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nChainring SRAM EX1, 18T, steel\r\nCassette SRAM EX1, 11-48, 8 speed\r\nChain SRAM EX1, 8 speed\r\n\r\nComponents\r\nSaddle Bontrager Arvada, hollow chromoly rails, 138mm width\r\nSeatpost Bontrager Line Elite Dropper, internal routing, 31.6mm\r\nHandlebar Bontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\r\nGrips Bontrager XR Trail Elite, alloy lock-on\r\nStem Bontrager Line Pro, 35mm, Knock Block, Blendr compatible, 0 degree, 50mm length\r\nHeadset Knock Block Integrated, 62-degree radius, cartridge bearing, 1-1\/8'' top, 1.5'' bottom\r\nBrake SRAM G2 RSC hydraulic disc, carbon levers\r\nBrake rotor SRAM Centerline, centerlock, round edge, 200mm\r\n\r\nAccessories\r\nE-bike system Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nBattery Bosch PowerTube 625, 625Wh\r\nCharger Bosch 4A standard charger\r\nController Bosch Kiox with Anti-theft solution, Bluetooth connectivity, 1.9'' display\r\nTool Bontrager Switch thru axle, removable lever\r\n\r\nWeight\r\nWeight M - 20.25 kg \/ 44.6 lbs (with TLR sealant, no tubes)\r\nWeight limit This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\r\n\r\n## Sizing & fit\r\n\r\n| Size | Rider Height | Inseam |\r\n|:----:|:------------------------:|:--------------------:|\r\n| S | 155 - 170 cm 5'1\" - 5'7\" | 73 - 80 cm 29\" - 31.5\" |\r\n| M | 163 - 178 cm 5'4\" - 5'10\" | 77 - 83 cm 30.5\" - 32.5\" |\r\n| L | 176 - 191 cm 5'9\" - 6'3\" | 83 - 89 cm 32.5\" - 35\" |\r\n| XL | 188 - 198 cm 6'2\" - 6'6\" | 88 - 93 cm 34.5\" - 36.5\" |\r\n\r\n\r\n## Geometry\r\n\r\nAll measurements provided in cm unless otherwise noted.\r\nSizing table\r\n| Frame size letter | S | M | L | XL |\r\n|---------------------------|-------|-------|-------|-------|\r\n| Actual frame size | 15.8 | 17.8 | 19.8 | 21.8 |\r\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\r\n| A \u2014 Seat tube | 40.0 | 42.5 | 47.5 | 51.0 |\r\n| B \u2014 Seat tube angle | 72.5\u00B0 | 72.8\u00B0 | 73.0\u00B0 | 73.0\u00B0 |\r\n| C \u2014 Head tube length | 9.5 | 10.5 | 11.0 | 11.5 |\r\n| D \u2014 Head angle | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 |\r\n| E \u2014 Effective top tube | 59.0 | 62.0 | 65.0 | 68.0 |\r\n| F \u2014 Bottom bracket height | 32.5 | 32.5 | 32.5 | 32.5 |\r\n| G \u2014 Bottom bracket drop | 5.5 | 5.5 | 5.5 | 5.5 |\r\n| H \u2014 Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\r\n| I \u2014 Offset | 4.5 | 4.5 | 4.5 | 4.5 |\r\n| J \u2014 Trail | 11.0 | 11.0 | 11.0 | 11.0 |\r\n| K \u2014 Wheelbase | 113.0 | 117.0 | 120.0 | 123.0 |\r\n| L \u2014 Standover | 77.0 | 77.0 | 77.0 | 77.0 |\r\n| M \u2014 Frame reach | 41.0 | 44.5 | 47.5 | 50.0 |\r\n| N \u2014 Frame stack | 61.0 | 62.0 | 62.5 | 63.0 |", - "price": 1499.99, - "tags": [ - "bicycle" - ] - }, - { - "name": "Enduro X Pro", - "shortDescription": "The Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame and top-of-the-line components, this bike is ready to tackle any trail, from technical downhill descents to grueling uphill climbs.", - "text": "## Overview\nIt's right for you if...\nYou're an experienced mountain biker who wants a high-performance bike that can handle any terrain. You want a bike with the best components available, including a full carbon frame, suspension system, and hydraulic disc brakes.\n\nThe tech you get\nOur top-of-the-line full carbon frame with aggressive geometry and a slack head angle for maximum control. It's equipped with a Fox Factory suspension system with 170mm of travel in the front and 160mm in the rear, a Shimano XTR 12-speed drivetrain, and hydraulic disc brakes for maximum stopping power. The bike also features a dropper seatpost for easy adjustments on the fly.\n\nThe final word\nThe Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame, top-of-the-line components, and aggressive geometry, this bike is ready to take on any trail. Whether you're a seasoned pro or just starting out, the Enduro X Pro will help you take your riding to the next level.\n\n## Features\nFull carbon frame\nAggressive geometry with a slack head angle\nFox Factory suspension system with 170mm of travel in the front and 160mm in the rear\nShimano XTR 12-speed drivetrain\nHydraulic disc brakes for maximum stopping power\nDropper seatpost for easy adjustments on the fly\n\n## Specifications\nFrameset\nFrame\tFull carbon frame\nFork\tFox Factory suspension system with 170mm of travel\nRear suspension\tFox Factory suspension system with 160mm of travel\n\nWheels\nWheel size\t27.5\" or 29\"\nTires\tTubeless-ready Maxxis tires\n\nDrivetrain\nShifters\tShimano XTR 12-speed\nFront derailleur\tN/A\nRear derailleur\tShimano XTR\nCrankset\tShimano XTR\nCassette\tShimano XTR 12-speed\nChain\tShimano XTR\n\nComponents\nBrakes\tHydraulic disc brakes\nHandlebar\tAlloy handlebar\nStem\tAlloy stem\nSeatpost\tDropper seatpost\n\nAccessories\nPedals\tNot included\n\nWeight\nWeight\tApproximately 27-29 lbs\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 5'4\" - 5'8\" (162-172cm) |\n| M | 5'8\" - 5'11\" (172-180cm) |\n| L | 5'11\" - 6'3\" (180-191cm) |\n| XL | 6'3\" - 6'6\" (191-198cm) |\n\n## Geometry\n| Size | S | M | L | XL |\n|:----:|:---------------:|:---------------:|:-----------------:|:---------------:|\n| A - Seat tube length | 390mm | 425mm | 460mm | 495mm |\n| B - Effective top tube length | 585mm | 610mm | 635mm | 660mm |\n| C - Head tube angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| D - Seat tube angle | 76° | 76° | 76° | 76° |\n| E - Chainstay length | 435mm | 435mm | 435mm | 435mm |\n| F - Head tube length | 100mm | 110mm | 120mm | 130mm |\n| G - BB drop | 20mm | 20mm | 20mm | 20mm |\n| H - Wheelbase | 1155mm | 1180mm | 1205mm | 1230mm |\n| I - Standover height | 780mm | 800mm | 820mm | 840mm |\n| J - Reach | 425mm | 450mm | 475mm | 500mm |\n| K - Stack | 610mm | 620mm | 630mm | 640mm |", - "price": 599.99, - "tags": [ - "bicycle" - ] - }, - { - "name": "Blaze X1", - "shortDescription": "Blaze X1 is a high-performance road bike that offers superior speed and agility, making it perfect for competitive racing or fast-paced group rides. The bike features a lightweight carbon frame, aerodynamic tube shapes, a 12-speed Shimano Ultegra drivetrain, and hydraulic disc brakes for precise stopping power. With its sleek design and cutting-edge technology, Blaze X1 is a bike that is built to perform and dominate on any road.", - "description": "## Overview\nIt's right for you if...\nYou're a competitive road cyclist or an enthusiast who enjoys fast-paced group rides. You want a bike that is lightweight, agile, and delivers exceptional speed.\n\nThe tech you get\nBlaze X1 features a lightweight carbon frame with a tapered head tube and aerodynamic tube shapes for maximum speed and efficiency. The bike is equipped with a 12-speed Shimano Ultegra drivetrain for smooth and precise shifting, Shimano hydraulic disc brakes for powerful and reliable stopping power, and Bontrager Aeolus Elite 35 carbon wheels for increased speed and agility.\n\nThe final word\nBlaze X1 is a high-performance road bike that is designed to deliver exceptional speed and agility. With its cutting-edge technology and top-of-the-line components, it's a bike that is built to perform and dominate on any road.\n\n## Features\nSpeed and efficiency\nBlaze X1's lightweight carbon frame and aerodynamic tube shapes offer maximum speed and efficiency, allowing you to ride faster and farther with ease.\n\nPrecision stopping power\nShimano hydraulic disc brakes provide precise and reliable stopping power, even in wet or muddy conditions.\n\nAgility and control\nBontrager Aeolus Elite 35 carbon wheels make Blaze X1 incredibly agile and responsive, allowing you to navigate tight turns and corners with ease.\n\nSmooth and precise shifting\nThe 12-speed Shimano Ultegra drivetrain offers smooth and precise shifting, so you can easily find the right gear for any terrain.\n\n## Specifications\nFrameset\nFrame\tADV Carbon, tapered head tube, BB90, direct mount rim brakes, internal cable routing, DuoTrap S compatible, 130x9mm QR\nFork\tADV Carbon, tapered steerer, direct mount rim brakes, internal brake routing, 100x9mm QR\n\nWheels\nWheel front\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x9mm QR\nWheel rear\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11-speed freehub, 130x9mm QR\nTire front\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nTire rear\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nMax tire size\t25c Bontrager tires (with at least 4mm of clearance to frame)\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 12 speed\nFront derailleur\tShimano Ultegra R8000, braze-on\nRear derailleur\tShimano Ultegra R8000, short cage, 30T max cog\nCrank\tSize: 50, 52, 54\nShimano Ultegra R8000, 50/34 (compact), 170mm length\nSize: 56, 58, 60, 62\nShimano Ultegra R8000, 50/34 (compact), 172.5mm length\nBottom bracket\tBB90, Shimano press-fit\nCassette\tShimano Ultegra R8000, 11-30, 12 speed\nChain\tShimano Ultegra HG701, 12 speed\n\nComponents\nSaddle\tBontrager Montrose Elite, titanium rails, 138mm width\nSeatpost\tBontrager carbon seatmast cap, 20mm offset\nHandlebar\tBontrager Elite Aero VR-CF, alloy, 31.8mm, internal cable routing, 40cm width\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Elite, 31.8mm, Blendr-compatible, 7 degree, 80mm length\nBrake Shimano Ultegra hydraulic disc brake\n\nWeight\nWeight\t56 - 8.91 kg / 19.63 lbs (with tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider height |\n|------|-------------|\n| 50 | 162-166cm |\n| 52 | 165-170cm |\n| 54 | 168-174cm |\n| 56 | 174-180cm |\n| 58 | 179-184cm |\n| 60 | 184-189cm |\n| 62 | 189-196cm |\n\n## Geometry\n| Frame size | 50cm | 52cm | 54cm | 56cm | 58cm | 60cm | 62cm |\n|------------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A - Seat tube | 443mm | 460mm | 478mm | 500mm | 520mm | 540mm | 560mm |\n| B - Seat tube angle | 74.1° | 73.9° | 73.7° | 73.4° | 73.2° | 73.0° | 72.8° |\n| C - Head tube length | 100mm | 110mm | 130mm | 150mm | 170mm | 190mm | 210mm |\n| D - Head angle | 71.4° | 72.0° | 72.5° | 73.0° | 73.3° | 73.6° | 73.8° |\n| E - Effective top tube | 522mm | 535mm | 547mm | 562mm | 577mm | 593mm | 610mm |\n| F - Bottom bracket height | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm |\n| G - Bottom bracket drop | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm |\n| H - Chainstay length | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm |\n| I - Offset | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm |\n| J - Trail | 65mm | 62mm | 59mm | 56mm | 55mm | 53mm | 52mm |\n| K - Wheelbase | 983mm | 983mm | 990mm | 1005mm | 1019mm | 1036mm | 1055mm |\n| L - Standover | 741mm | 765mm | 787mm | 806mm | 825mm | 847mm | 869mm |", - "price": 799.99, - "tags": [ - "bicycle", - "mountain bike" - ] - }, - { - "name": "Celerity X5", - "shortDescription": "Celerity X5 is a versatile and reliable road bike that is designed for experienced and amateur riders alike. It's designed to provide smooth and comfortable rides over long distances. With an ultra-lightweight and responsive carbon fiber frame, Shimano 105 groupset, hydraulic disc brakes, and 28mm wide tires, this bike ensures efficient power transfer, precise handling, and superior stopping power.", - "description": "## Overview\n\nIt's right for you if... \nYou are looking for a high-performance road bike that offers a perfect balance of speed, comfort, and control. You enjoy long-distance rides and need a bike that is designed to handle various road conditions with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nCelerity X5 is equipped with a full carbon fiber frame that ensures maximum strength and durability while keeping the weight down. It features a Shimano 105 groupset with 11-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power, and 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that offers comfort, speed, and control, Celerity X5 is the perfect choice. With its lightweight carbon fiber frame, reliable components, and advanced technology, this bike is designed to help you enjoy long-distance rides with ease.\n\n## Features \n\nLightweight and responsive \nCelerity X5 comes with a full carbon fiber frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon seat post provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tCelerity X5 Full Carbon Fiber Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tCelerity X5 Full Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tCelerity X5 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano 105 R7025 Hydraulic Disc Shifters \nFront Derailleur\tShimano 105 R7000 \nRear Derailleur\tShimano 105 R7000 \nCrankset\tShimano 105 R7000 50-34T \nBottom Bracket\tShimano BB72-41B \nCassette\tShimano 105 R7000 11-30T \nChain\tShimano HG601 11-Speed Chain \n\nComponents \nSaddle\tSelle Royal Asphalt Saddle \nSeatpost\tCelerity X5 Carbon Seatpost \nHandlebar\tCelerity X5 Compact Handlebar \nStem\tCelerity X5 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano 105 R7025 Hydraulic Disc Brakes \nRotors\tShimano SM-RT70 160mm Rotors \n\nAccessories \nPedals\tCelerity X5 Road Pedals \n\nWeight \nWeight\t8.2 kg / 18.1 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", - "price": 399.99, - "tags": [ - "bicycle", - "city bike" - ] - }, - { - "name": "Velocity V8", - "shortDescription": "Velocity V8 is a high-performance road bike that is designed to deliver speed, agility, and control on the road. With its lightweight aluminum frame, carbon fiber fork, Shimano Tiagra groupset, and hydraulic disc brakes, this bike is perfect for experienced riders who are looking for a fast and responsive bike that can handle various road conditions.", - "description": "## Overview\n\nIt's right for you if... \nYou are an experienced rider who is looking for a high-performance road bike that is lightweight, agile, and responsive. You want a bike that can handle long-distance rides, steep climbs, and fast descents with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nVelocity V8 features a lightweight aluminum frame with a carbon fiber fork that ensures a comfortable ride without sacrificing stiffness and power transfer. It comes with a Shimano Tiagra groupset with 10-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power in all weather conditions, while 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that is lightweight, fast, and responsive, Velocity V8 is the perfect choice. With its lightweight aluminum frame, reliable components, and advanced technology, this bike is designed to help you enjoy fast and comfortable rides on the road.\n\n## Features \n\nLightweight and responsive \nVelocity V8 comes with a lightweight aluminum frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon fork provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tVelocity V8 Aluminum Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tVelocity V8 Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tVelocity V8 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano Tiagra Hydraulic Disc Shifters \nFront Derailleur\tShimano Tiagra \nRear Derailleur\tShimano Tiagra \nCrankset\tShimano Tiagra 50-34T \nBottom Bracket\tShimano BB-RS500-PB \nCassette\tShimano Tiagra 11-32T \nChain\tShimano HG54 10-Speed Chain \n\nComponents \nSaddle\tVelocity V8 Saddle \nSeatpost\tVelocity V8 Aluminum Seatpost \nHandlebar\tVelocity V8 Compact Handlebar \nStem\tVelocity V8 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano Tiagra Hydraulic Disc Brakes \nRotors\tShimano SM-RT64 160mm Rotors \n\nAccessories \nPedals\tVelocity V8 Road Pedals \n\nWeight \nWeight\t9.4 kg / 20.7 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", - "price": 1899.99, - "tags": [ - "bicycle", - "electric bike" - ] - }, - { - "name": "VeloCore X9 eMTB", - "shortDescription": "The VeloCore X9 eMTB is a light, agile and versatile electric mountain bike designed for adventure and performance. Its purpose-built frame and premium components offer an exhilarating ride experience on both technical terrain and smooth singletrack.", - "description": "## Overview\nIt's right for you if...\nYou love exploring new trails and testing your limits on challenging terrain. You want an electric mountain bike that offers power when you need it, without sacrificing performance or agility. You're looking for a high-quality bike with top-notch components and a sleek design.\n\nThe tech you get\nA lightweight, full carbon frame with custom geometry, a 140mm RockShox Pike Ultimate fork with Charger 2.1 damper, and a Fox Float DPS Performance shock. A Shimano STEPS E8000 motor and 504Wh battery that provide up to 62 miles of range and 20 mph assistance. A Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels.\n\nThe final word\nThe VeloCore X9 eMTB delivers power and agility in equal measure. It's a versatile and capable electric mountain bike that can handle any trail with ease. With premium components, a custom carbon frame, and a sleek design, this bike is built for adventure.\n\n## Features\nAgile and responsive\n\nThe VeloCore X9 eMTB is designed to be nimble and responsive on the trail. Its custom carbon frame offers a perfect balance of stiffness and compliance, while the suspension system provides smooth and stable performance on technical terrain.\n\nPowerful and efficient\n\nThe Shimano STEPS E8000 motor and 504Wh battery provide up to 62 miles of range and 20 mph assistance. The motor delivers smooth and powerful performance, while the battery offers reliable and consistent power for long rides.\n\nCustomizable ride experience\n\nThe VeloCore X9 eMTB comes with an intuitive and customizable Shimano STEPS display that allows you to adjust the level of assistance, monitor your speed and battery life, and customize your ride experience to suit your needs.\n\nPremium components\n\nThe VeloCore X9 eMTB is equipped with high-end components, including a Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels. These components offer reliable and precise performance, allowing you to push your limits with confidence.\n\n## Specs\nFrameset\nFrame\tVeloCore carbon fiber frame, Boost, tapered head tube, internal cable routing, 140mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 damper, DebonAir spring, 15x110mm Boost Maxle Ultimate, 46mm offset, 140mm travel\nShock\tFox Float DPS Performance, EVOL, 3-position adjust, Kashima Coat, 210x50mm\n\nWheels\nWheel front\tDT Swiss XM1700 Spline, 30mm internal width, 15x110mm Boost axle\nWheel rear\tDT Swiss XM1700 Spline, 30mm internal width, Shimano Microspline driver, 12x148mm Boost axle\nTire front\tMaxxis Minion DHF, 29x2.5\", EXO+ casing, tubeless ready\nTire rear\tMaxxis Minion DHR II, 29x2.4\", EXO+ casing, tubeless ready\n\nDrivetrain\nShifter\tShimano XT M8100, 12-speed\nRear derailleur\tShimano XT M8100, Shadow Plus, long cage, 51T max cog\nCrankset\tShimano STEPS E8000, 165mm length, 34T chainring\nCassette\tShimano XT M8100, 10-51T, 12-speed\nChain\tShimano CN-M8100, 12-speed\nPedals\tNot included\n\nComponents\nSaddle\tBontrager Arvada, hollow chromoly rails\nSeatpost\tDrop Line, internal routing, 31.6mm (15.5: 100mm, 17.5 & 18.5: 125mm, 19.5 & 21.5: 150mm)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nStem\tBontrager Line Pro, 35mm, Knock Block, 0 degree, 50mm length\nGrips\tBontrager XR Trail Elite, alloy lock-on\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrakeset\tShimano SLX M7120, 4-piston hydraulic disc\n\nAccessories\nBattery\tShimano STEPS BT-E8010, 504Wh\nCharger\tShimano STEPS EC-E8004, 4A\nController\tShimano STEPS E8000 display\nBike weight\tM - 22.5 kg / 49.6 lbs (with tubes)\n\n## Sizing & fit\n\n| Size | Rider Height |\n|:----:|:------------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" |\n| M | 170 - 178 cm 5'7\" - 5'10\"|\n| L | 178 - 186 cm 5'10\" - 6'1\"|\n| XL | 186 - 196 cm 6'1\" - 6'5\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| A — Seat tube | 40.6 | 43.2 | 47.0 | 51.0 |\n| B — Seat tube angle | 75.0° | 75.0° | 75.0° | 75.0° |\n| C — Head tube length | 9.6 | 10.6 | 11.6 | 12.6 |\n| D — Head angle | 66.5° | 66.5° | 66.5° | 66.5° |\n| E — Effective top tube | 60.4 | 62.6 | 64.8 | 66.9 |\n| F — Bottom bracket height | 33.2 | 33.2 | 33.2 | 33.2 |\n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 |\n| H — Chainstay length | 45.5 | 45.5 | 45.5 | 45.5 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 11.9 | 11.9 | 11.9 | 11.9 |\n| K — Wheelbase | 117.0 | 119.3 | 121.6 | 123.9 |\n| L — Standover | 75.9 | 75.9 | 78.6 | 78.6 |\n| M — Frame reach | 43.6 | 45.6 | 47.6 | 49.6 |\n| N — Frame stack | 60.5 | 61.5 | 62.4 | 63.4 |", - "price": 1299.99, - "tags": [ - "bicycle", - "touring bike" - ] - }, - { - "name": "Zephyr 8.8 GX Eagle AXS Gen 3", - "shortDescription": "Zephyr 8.8 GX Eagle AXS is a light and nimble full-suspension mountain bike. It's designed to handle technical terrain with ease and has a smooth and efficient ride feel. The sleek and powerful Bosch Performance Line CX motor and removable Powertube battery provide a boost to your pedaling and give you long-lasting riding time. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.", - "description": "## Overview\nIt's right for you if...\nYou're an avid mountain biker looking for a high-performance e-MTB that can tackle challenging trails. You want a bike with a powerful motor, efficient suspension, and advanced technology to enhance your riding experience. You also need a bike that's reliable and durable for long-lasting use.\n\nThe tech you get\nA lightweight, full carbon frame with 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. A Bosch Performance Line CX motor and removable Powertube 625Wh battery that can assist up to 20mph when it's on and gives zero drag when it's off, plus an easy-to-use handlebar-mounted Bosch Purion controller. A SRAM GX Eagle AXS wireless electronic drivetrain, a RockShox Reverb Stealth dropper, and DT Swiss HX1501 Spline One wheels.\n\nThe final word\nZephyr 8.8 GX Eagle AXS is a high-performance e-MTB that's designed to handle technical terrain with ease. With a powerful Bosch motor and long-lasting battery, you can conquer challenging climbs and enjoy long rides. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.\n\n## Features\nPowerful motor\n\nThe Bosch Performance Line CX motor provides a boost to your pedaling and can assist up to 20mph. It has four power modes and a walk-assist function for easy navigation on steep climbs. The motor is also reliable and durable for long-lasting use.\n\nEfficient suspension\n\nZephyr 8.8 has a 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. The suspension is efficient and responsive, allowing you to handle technical terrain with ease.\n\nRemovable battery\n\nThe Powertube 625Wh battery is removable for easy charging and storage. It provides long-lasting riding time and can be replaced with a spare battery for even longer rides. The battery is also durable and weather-resistant for all-season riding.\n\nAdvanced technology\n\nZephyr 8.8 is equipped with advanced technology, including a Bosch Purion controller for easy motor control, a SRAM GX Eagle AXS wireless electronic drivetrain for precise shifting, and a RockShox Reverb Stealth dropper for adjustable saddle height. The bike also has DT Swiss HX1501 Spline One wheels for reliable performance on any terrain.\n\nCarbon frame\n\nThe full carbon frame is lightweight and durable, providing a smooth and efficient ride. It's also designed with a tapered head tube, internal cable routing, and Boost148 spacing for enhanced stiffness and responsiveness.\n\n## Specs\nFrameset\nFrame\tCarbon main frame & stays, tapered head tube, internal routing, Boost148, 150mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 RCT3 damper, DebonAir spring, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 160mm travel\nShock\tRockShox Deluxe RT3, DebonAir spring, 205mm x 57.5mm\nMax compatible fork travel\t170mm\n\nWheels\nWheel front\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, 110x15mm Boost\nWheel rear\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, SRAM XD driver, 148x12mm Boost\nTire\tBontrager XR4 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.40''\nMax tire size\t29x2.60\"\n\nDrivetrain\nShifter\tSRAM GX Eagle AXS, wireless, 12 speed\nRear derailleur\tSRAM GX Eagle AXS\nCrank\tBosch Gen 4, 32T\nChainring\tSRAM X-Sync 2, 32T, direct-mount\nCassette\tSRAM PG-1275 Eagle, 10-52, 12 speed\nChain\tSRAM GX Eagle, 12 speed\n\nComponents\nSaddle\tBontrager Arvada, hollow titanium rails, 138mm width\nSeatpost\tRockShox Reverb Stealth, 31.6mm, internal routing, 150mm (S), 170mm (M/L), 200mm (XL)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nGrips\tBontrager XR Trail Elite, alloy lock-on\nStem\tBontrager Line Pro, Knock Block, 35mm, 0 degree, 50mm length\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake\tSRAM Code RSC hydraulic disc, 200mm (front), 180mm (rear)\nBrake rotor\tSRAM CenterLine, centerlock, round edge, 200mm (front), 180mm (rear)\n\nAccessories\nE-bike system\tBosch Performance Line CX\nBattery\tBosch Powertube 625Wh\nCharger\tBosch 4A compact charger\nController\tBosch Purion\nTool\tBontrager multi-tool, integrated storage bag\n\nWeight\nWeight\tM - 24.08 kg / 53.07 lbs (with TLR sealant, no tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 153 - 162 cm 5'0\" - 5'4\" | 67 - 74 cm 26\" - 29\" |\n| M | 161 - 172 cm 5'3\" - 5'8\" | 74 - 79 cm 29\" - 31\" |\n| L | 171 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| XL | 179 - 188 cm 5'10\" - 6'2\" | 84 - 89 cm 33\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 41.9 | 44.5 | 47.6 |\n| B — Seat tube angle | 76.1° | 76.1° | 76.1° | 76.1° |\n| C — Head tube length | 9.6 | 10.5 | 11.5 | 12.5 |\n| D — Head angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| E — Effective top tube | 58.6 | 61.3 | 64.0 | 66.7 |\n| F — Bottom bracket height | 34.0 | 34.0 | 34.0 | 34.0 |\n| G — Bottom bracket drop | 1.0 | 1.0 | 1.0 | 1.0 |\n| H — Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 10.5 | 10.5 | 10.5 | 10.5 |\n| K — Wheelbase | 119.5 | 122.3 | 125.0 | 127.8 |\n| L — Standover | 72.7 | 74.7 | 77.6 | 81.0 |\n|", - "price": 1499.99, - "tags": [ - "bicycle", - "electric bike", - "city bike" - ] - }, - { - "name": "Velo 99 XR1 AXS", - "shortDescription": "Velo 99 XR1 AXS is a next-generation bike designed for fast-paced adventure seekers and speed enthusiasts. Built for high-performance racing, the bike boasts state-of-the-art technology and premium components. It is the ultimate bike for riders who want to push their limits and get their adrenaline pumping.", - "description": "## Overview\nIt's right for you if...\nYou are a passionate cyclist looking for a bike that can keep up with your speed, agility, and endurance. You are an adventurer who loves to explore new terrains and challenge yourself on the toughest courses. You want a bike that is lightweight, durable, and packed with the latest technology.\n\nThe tech you get\nA lightweight, full carbon frame with advanced aerodynamics and integrated cable routing for a clean look. A high-performance SRAM XX1 Eagle AXS wireless electronic drivetrain, featuring a 12-speed cassette and a 32T chainring. A RockShox SID Ultimate fork with a remote lockout, 120mm travel, and Charger Race Day damper. A high-end SRAM G2 Ultimate hydraulic disc brake with carbon levers. A FOX Transfer SL dropper post for quick and easy height adjustments. DT Swiss XRC 1501 carbon wheels for superior speed and handling.\n\nThe final word\nVelo 99 XR1 AXS is a premium racing bike that can help you achieve your goals and reach new heights. It is designed for speed, agility, and performance, and it is packed with the latest technology and premium components. If you are a serious cyclist who wants the best, this is the bike for you.\n\n## Features\nAerodynamic design\n\nThe Velo 99 XR1 AXS features a state-of-the-art frame design that reduces drag and improves speed. It has an aerodynamic seatpost, integrated cable routing, and a sleek, streamlined look that sets it apart from other bikes.\n\nWireless electronic drivetrain\n\nThe SRAM XX1 Eagle AXS drivetrain features a wireless electronic system that provides precise, instant shifting and unmatched efficiency. It eliminates the need for cables and makes the bike lighter and faster.\n\nHigh-performance suspension\n\nThe RockShox SID Ultimate fork and Charger Race Day damper provide 120mm of smooth, responsive suspension that can handle any terrain. The fork also has a remote lockout for quick adjustments on the fly.\n\nSuperior braking power\n\nThe SRAM G2 Ultimate hydraulic disc brake system delivers unmatched stopping power and control. It has carbon levers for a lightweight, ergonomic design and precision control.\n\nCarbon wheels\n\nThe DT Swiss XRC 1501 carbon wheels are ultra-lightweight, yet incredibly strong and durable. They provide superior speed and handling, making the bike more agile and responsive.\n\n## Specs\nFrameset\nFrame\tFull carbon frame, integrated cable routing, aerodynamic design, Boost148\nFork\tRockShox SID Ultimate, Charger Race Day damper, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 120mm travel\n\nWheels\nWheel front\tDT Swiss XRC 1501 carbon wheel, Boost110, 15mm thru axle\nWheel rear\tDT Swiss XRC 1501 carbon wheel, SRAM XD driver, Boost148, 12mm thru axle\nTire\tSchwalbe Racing Ray, Performance Line, Addix, 29x2.25\"\nTire part\tSchwalbe Doc Blue Professional, 500ml\nMax tire size\t29x2.3\"\n\nDrivetrain\nShifter\tSRAM Eagle AXS, wireless, 12-speed\nRear derailleur\tSRAM XX1 Eagle AXS\nCrank\tSRAM XX1 Eagle, 32T, carbon\nChainring\tSRAM X-SYNC, 32T, alloy\nCassette\tSRAM Eagle XG-1299, 10-52, 12-speed\nChain\tSRAM XX1 Eagle, 12-speed\nMax chainring size\t1x: 32T\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tFOX Transfer SL, 125mm travel, internal routing, 31.6mm\nHandlebar\tBontrager Kovee Pro, ADV Carbon, 35mm, 5mm rise, 720mm width\nGrips\tBontrager XR Endurance Elite\nStem\tBontrager Kovee Pro, 35mm, Blendr compatible, 7 degree, 60mm length\nHeadset\tIntegrated, cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrake\tSRAM G2 Ultimate hydraulic disc, carbon levers, 180mm rotors\n\nAccessories\nBike computer\tBontrager Trip 300\nTool\tBontrager Flatline Pro pedal wrench, T25 Torx\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 158 - 168 cm 5'2\" - 5'6\" | 74 - 78 cm 29\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| L | 173 - 183 cm 5'8\" - 6'0\" | 82 - 86 cm 32\" - 34\" |\n| XL | 180 - 193 cm 5'11\" - 6'4\" | 86 - 90 cm 34\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.9 | 43.0 | 47.0 | 51.0 |\n| B — Seat tube angle | 74.5° | 74.5° | 74.5° | 74.5° |\n| C — Head tube length | 9.0 | 10.0 | 11.0 | 12.0 |\n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° |\n| E — Effective top tube | 57.8 | 59.7 | 61.6 | 63.6 |\n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 9.7 | 9.7 | 9.7 | 9.7 |\n| K — Wheelbase | 112.5 | 114.5 | 116.5 | 118.6 |\n| L — Standover | 75.9 | 77.8 | 81.5 | 84.2 |\n| M — Frame reach | 41.6 | 43.4 | 45.2 | 47.1 |\n| N — Frame stack | 58.2 | 58.9 | 59.3 | 59.9 |", - "price": 1099.99, - "tags": [ - "bicycle", - "mountain bike" - ] - }, - { - "name": "AURORA 11S E-MTB", - "shortDescription": "The AURORA 11S is a powerful and stylish electric mountain bike designed to take you on thrilling off-road adventures. With its sturdy frame and premium components, this bike is built to handle any terrain. It features a high-performance motor, long-lasting battery, and advanced suspension system that guarantee a smooth and comfortable ride.", - "description": "## Overview\nIt's right for you if...\nYou want a top-of-the-line e-MTB that is both powerful and stylish. You also want a bike that can handle any terrain, from steep climbs to rocky descents. With its advanced features and premium components, the AURORA 11S is designed for serious off-road riders who demand the best.\n\nThe tech you get\nA sturdy aluminum frame with advanced suspension system that provides 120mm of travel. A 750W brushless motor that delivers up to 28mph, and a 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge. An advanced 11-speed Shimano drivetrain with hydraulic disc brakes for precise shifting and reliable stopping power. \n\nThe final word\nThe AURORA 11S is a top-of-the-line e-MTB that delivers exceptional performance and style. Whether you're tackling steep climbs or hitting rocky descents, this bike is built to handle any terrain with ease. With its advanced features and premium components, the AURORA 11S is the perfect choice for serious off-road riders who demand the best.\n\n## Features\nPowerful and efficient\n\nThe AURORA 11S is equipped with a high-performance 750W brushless motor that delivers up to 28mph. The motor is powered by a long-lasting 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge.\n\nAdvanced suspension system\n\nThe bike's advanced suspension system provides 120mm of travel, ensuring a smooth and comfortable ride on any terrain. The front suspension is a Suntour XCR32 Air fork, while the rear suspension is a KS-281 hydraulic shock absorber.\n\nPremium components\n\nThe AURORA 11S features an advanced 11-speed Shimano drivetrain with hydraulic disc brakes. The bike is also equipped with a Tektro HD-E725 hydraulic disc brake system that provides reliable stopping power.\n\nSleek and stylish design\n\nWith its sleek and stylish design, the AURORA 11S is sure to turn heads on the trail. The bike's sturdy aluminum frame is available in a range of colors, including black, blue, and red.\n\n## Specs\nFrameset\nFrame Material: Aluminum\nFrame Size: S, M, L\nFork: Suntour XCR32 Air, 120mm Travel\nShock Absorber: KS-281 Hydraulic Shock Absorber\n\nWheels\nWheel Size: 27.5 inches\nTires: Kenda K1151 Nevegal, 27.5x2.35\nRims: Alloy Double Wall\nSpokes: 32H, Stainless Steel\n\nDrivetrain\nShifters: Shimano SL-M7000\nRear Derailleur: Shimano RD-M8000\nCrankset: Prowheel 42T, Alloy Crank Arm\nCassette: Shimano CS-M7000, 11-42T\nChain: KMC X11EPT\n\nBrakes\nBrake System: Tektro HD-E725 Hydraulic Disc Brake\nBrake Rotors: 180mm Front, 160mm Rear\n\nE-bike system\nMotor: 750W Brushless\nBattery: 48V/14Ah Lithium-Ion\nCharger: 48V/3A Smart Charger\nController: Intelligent Sinusoidal Wave\n\nWeight\nWeight: 59.5 lbs\n\n## Sizing & fit\n| Size | Rider Height | Standover Height |\n|------|-------------|-----------------|\n| S | 5'2\"-5'6\" | 28.5\" |\n| M | 5'7\"-6'0\" | 29.5\" |\n| L | 6'0\"-6'4\" | 30.5\" |\n\n## Geometry\nAll measurements provided in cm.\nSizing table\n| Frame size letter | S | M | L |\n|-------------------|-----|-----|-----|\n| Wheel Size | 27.5\"| 27.5\"| 27.5\"|\n| Seat tube length | 44.5| 48.5| 52.5|\n| Head tube angle | 68° | 68° | 68° |\n| Seat tube angle | 74.5°| 74.5°| 74.5°|\n| Effective top tube | 57.5| 59.5| 61.5|\n| Head tube length | 12.0| 12.0| 13.0|\n| Chainstay length | 45.5| 45.5| 45.5|\n| Bottom bracket height | 30.0| 30.0| 30.0|\n| Wheelbase | 115.0|116.5|118.5|", - "price": 1999.99, - "tags": [ - "bicycle", - "road bike" - ] - }, - { - "name": "VeloTech V9.5 AXS Gen 3", - "shortDescription": "VeloTech V9.5 AXS is a sleek and fast carbon bike that combines high-end tech with a comfortable ride. It's designed to provide the ultimate experience for the most serious riders. The bike comes with a lightweight and powerful motor that can be activated when needed, and you get a spec filled with premium parts.", - "description": "## Overview\nIt's right for you if...\nYou want a bike that is fast, efficient, and delivers an adrenaline-filled experience. You are looking for a bike that is built with cutting-edge technology, and you want a ride that is both comfortable and exciting.\n\nThe tech you get\nA lightweight and durable full carbon frame with a fork that has 100mm of travel. The bike comes with a powerful motor that can deliver up to 20 mph of assistance. The drivetrain is a wireless electronic system that is precise and reliable. The bike is also equipped with hydraulic disc brakes, tubeless-ready wheels, and comfortable grips.\n\nThe final word\nThe VeloTech V9.5 AXS is a high-end bike that delivers an incredible experience for serious riders. It combines the latest technology with a comfortable ride, making it perfect for long rides, tough climbs, and fast descents.\n\n## Features\nFast and efficient\nThe VeloTech V9.5 AXS comes with a powerful motor that can provide up to 20 mph of assistance. The motor is lightweight and efficient, providing a boost when you need it without adding bulk. The bike's battery is removable, allowing you to ride without assistance when you don't need it.\n\nSmart software for the trail\nThe VeloTech V9.5 AXS is equipped with intelligent software that delivers a smooth and responsive ride. The software allows the motor to respond immediately as you start to pedal, delivering more power over a wider cadence range. You can also customize your user settings to suit your preferences.\n\nComfortable ride\nThe VeloTech V9.5 AXS is designed to provide a comfortable ride, even on long rides. The bike's fork has 100mm of travel, providing ample cushioning for rough terrain. The bike's grips are also designed to provide a comfortable and secure grip, even on the most challenging rides.\n\n## Specs\nFrameset\nFrame\tCarbon fiber frame with internal cable routing and Boost148\nFork\t100mm of travel with remote lockout\nShock\tN/A\n\nWheels\nWheel front\tCarbon fiber tubeless-ready wheel\nWheel rear\tCarbon fiber tubeless-ready wheel\nSkewer rear\t12mm thru-axle\nTire\tTubeless-ready tire\nTire part\tTubeless sealant\n\nDrivetrain\nShifter\tWireless electronic shifter\nRear derailleur\tWireless electronic derailleur\nCrank\tCarbon fiber crankset with chainring\nCrank arm\tCarbon fiber crank arm\nChainring\tAlloy chainring\nCassette\t12-speed cassette\nChain\t12-speed chain\n\nComponents\nSaddle\tCarbon fiber saddle\nSeatpost\tCarbon fiber seatpost\nHandlebar\tCarbon fiber handlebar\nGrips\tComfortable and secure grips\nStem\tCarbon fiber stem\nHeadset\tCarbon fiber headset\nBrake\tHydraulic disc brakes\nBrake rotor\tDisc brake rotor\n\nAccessories\nE-bike system\tPowerful motor with removable battery\nBattery\tLithium-ion battery\nCharger\tFast charging adapter\nController\tHandlebar-mounted controller\nTool\tBasic toolkit\n\nWeight\nWeight\tM - 17.5 kg / 38.5 lbs (with tubeless sealant)\n\nWeight limit\nThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing & fit\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 160 - 170 cm 5'3\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| M | 170 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| L | 180 - 190 cm 5'11\" - 6'3\" | 84 - 89 cm 33\" - 35\" |\n| XL | 190 - 200 cm 6'3\" - 6'7\" | 89 - 94 cm 35\" - 37\" |\n\n## Geometry\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 50.0 | 53.3 | 55.6 | 58.8 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 43.2 | 48.3 | 53.3 |\n| B — Seat tube angle | 72.3° | 72.6° | 72.8° | 72.8° |\n| C — Head tube length | 9.0 | 10.0 | 10.5 | 11.0 |\n| D — Head angle | 67.5° | 67.5° | 67.5° | 67.5° |\n| E — Effective top tube | 58.0 | 61.7 | 64.8 | 67.0 |\n| F — Bottom bracket height | 32.3 | 32.3 | 32.3 | 32.3 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 44.7 | 44.7 | 44.7 | 44.7 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 |\n| K — Wheelbase | 112.6 | 116.5 | 119.7 | 121.9 |\n| L — Standover | 76.8 | 76.8 | 76.8 | 76.8 |\n| M — Frame reach | 40.5 | 44.0 | 47.0 | 49.0 |\n| N — Frame stack | 60.9 | 61.8 | 62.2 | 62.7 |", - "price": 1699.99, - "tags": [ - "bicycle", - "electric bike", - "city bike" - ] - }, - { - "name": "Axiom D8 E-Mountain Bike", - "shortDescription": "The Axiom D8 is an electrifying mountain bike that is built for adventure. It boasts a light aluminum frame, a powerful motor and the latest tech to tackle the toughest of terrains. The D8 provides assistance without adding bulk to the bike, giving you the flexibility to ride like a traditional mountain bike or have an extra push when you need it.", - "description": "## Overview \nIt's right for you if... \nYou're looking for an electric mountain bike that can handle a wide variety of terrain, from flowing singletrack to technical descents. You also want a bike that offers a powerful motor that provides assistance without adding bulk to the bike. The D8 is designed to take you anywhere, quickly and comfortably.\n\nThe tech you get \nA lightweight aluminum frame with 140mm of travel, a Suntour fork with hydraulic lockout, and a reliable and powerful Bafang M400 mid-motor that provides a boost up to 20 mph. The bike features a Shimano Deore drivetrain, hydraulic disc brakes, and a dropper seat post. With the latest tech on-board, the D8 is designed to take you to new heights.\n\nThe final word \nThe Axiom D8 is an outstanding electric mountain bike that is designed for adventure. It's built with the latest tech and provides the flexibility to ride like a traditional mountain bike or have an extra push when you need it. Whether you're a beginner or an experienced rider, the D8 is the perfect companion for your next adventure.\n\n## Features \nBuilt for Adventure \n\nThe D8 features a lightweight aluminum frame that is built to withstand rugged terrain. It comes equipped with 140mm of travel and a Suntour fork that can handle even the toughest of trails. With this bike, you're ready to take on anything the mountain can throw at you.\n\nPowerful Motor \n\nThe Bafang M400 mid-motor provides reliable and powerful assistance without adding bulk to the bike. You can quickly and easily switch between the different assistance levels to find the perfect balance between range and power.\n\nShimano Deore Drivetrain \n\nThe Shimano Deore drivetrain is reliable and offers smooth shifting on any terrain. You can easily adjust the gears to match your riding style and maximize your performance on the mountain.\n\nDropper Seat Post \n\nThe dropper seat post allows you to easily adjust your seat height on the fly, so you can maintain the perfect position for any terrain. With the flick of a switch, you can quickly and easily lower or raise your seat to match the terrain.\n\nHydraulic Disc Brakes \n\nThe D8 features powerful hydraulic disc brakes that offer reliable stopping power in any weather condition. You can ride with confidence knowing that you have the brakes to stop on a dime.\n\n## Specs \nFrameset \nFrame\tAluminum frame with 140mm of travel \nFork\tSuntour fork with hydraulic lockout, 140mm of travel \nShock\tN/A \nMax compatible fork travel\t140mm \n \nWheels \nWheel front\tAlloy wheel \nWheel rear\tAlloy wheel \nSkewer rear\tThru axle \nTire\t29\" x 2.35\" \nTire part\tN/A \nMax tire size\t29\" x 2.6\" \n \nDrivetrain \nShifter\tShimano Deore \nRear derailleur\tShimano Deore \nCrank\tBafang M400 \nCrank arm\tN/A \nChainring\tN/A \nCassette\tShimano Deore \nChain\tShimano Deore \nMax chainring size\tN/A \n \nComponents \nSaddle\tAxiom D8 saddle \nSeatpost\tDropper seat post \nHandlebar\tAxiom D8 handlebar \nGrips\tAxiom D8 grips \nStem\tAxiom D8 stem \nHeadset\tAxiom D8 headset \nBrake\tHydraulic disc brakes \nBrake rotor\t180mm \n\nAccessories \nE-bike system\tBafang M400 mid-motor \nBattery\tLithium-ion battery, 500Wh \nCharger\tLithium-ion charger \nController\tBafang M400 controller \nTool\tN/A \n \nWeight \nWeight\tM - 22 kg / 48.5 lbs \nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 136 kg (300 lbs). \n \n \n## Sizing & fit \n \n| Size | Rider Height | Inseam | \n|:----:|:------------------------:|:--------------------:| \n| S | 152 - 165 cm 5'0\" - 5'5\" | 70 - 76 cm 27\" - 30\" | \n| M | 165 - 178 cm 5'5\" - 5'10\" | 76 - 81 cm 30\" - 32\" | \n| L | 178 - 185 cm 5'10\" - 6'1\" | 81 - 86 cm 32\" - 34\" | \n| XL | 185 - 193 cm 6'1\" - 6'4\" | 86 - 91 cm 34\" - 36\" | \n \n \n## Geometry \n \nAll measurements provided in cm unless otherwise noted. \nSizing table \n| Frame size letter | S | M | L | XL | \n|---------------------------|-------|-------|-------|-------| \n| Actual frame size | 41.9 | 46.5 | 50.8 | 55.9 | \n| Wheel size | 29\" | 29\" | 29\" | 29\" | \n| A — Seat tube | 42.0 | 46.5 | 51.0 | 56.0 | \n| B — Seat tube angle | 74.0° | 74.0° | 74.0° | 74.0° | \n| C — Head tube length | 11.0 | 12.0 | 13.0 | 15.0 | \n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° | \n| E — Effective top tube | 57.0 | 60.0 | 62.0 | 65.0 | \n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 | \n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 | \n| H — Chainstay length | 46.0 | 46.0 | 46.0 | 46.0 | \n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | \n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 | \n| K — Wheelbase | 113.0 | 116.0 | 117.5 | 120.5 | \n| L — Standover | 73.5 | 75.5 | 76.5 | 79.5 | \n| M — Frame reach | 41.0 | 43.5 | 45.0 | 47.5 | \n| N — Frame stack | 60.5 | 61.5 | 62.5 | 64.5 |", - "price": 1399.99, - "tags": [ - "bicycle", - "electric bike", - "mountain bike" - ] - }, - { - "name": "Velocity X1", - "shortDescription": "Velocity X1 is a high-performance road bike designed for speed enthusiasts. It features a lightweight yet durable frame, aerodynamic design, and top-quality components, making it the perfect choice for those who want to take their cycling experience to the next level.", - "description": "## Overview\nIt's right for you if...\nYou're an experienced cyclist looking for a bike that can keep up with your need for speed. You want a bike that's lightweight, aerodynamic, and built to perform, whether you're training for a race or just pushing yourself to go faster.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork, Shimano Ultegra groupset with a wide range of gearing, hydraulic disc brakes, aerodynamic carbon wheels, and a vibration-absorbing handlebar with ergonomic grips.\n\nThe final word\nVelocity X1 is the ultimate road bike for speed enthusiasts. Its lightweight frame, aerodynamic design, and top-quality components make it the perfect choice for those who want to take their cycling experience to the next level.\n\n\n## Features\n\nAerodynamic design\nVelocity X1 is built with an aerodynamic design to help you go faster with less effort. It features a sleek profile, hidden cables, and a carbon fork that cuts through the wind, reducing drag and increasing speed.\n\nHydraulic disc brakes\nVelocity X1 comes equipped with hydraulic disc brakes, providing excellent stopping power in all weather conditions. They're also low maintenance, with minimal adjustments needed over time.\n\nCarbon wheels\nThe Velocity X1's aerodynamic carbon wheels provide excellent speed and responsiveness, helping you achieve your fastest times yet. They're also lightweight, reducing overall bike weight and making acceleration and handling even easier.\n\nShimano Ultegra groupset\nThe Shimano Ultegra groupset provides smooth shifting and reliable performance, ensuring you get the most out of every ride. With a wide range of gearing options, it's ideal for tackling any terrain, from steep climbs to fast descents.\n\n\n## Specifications\nFrameset\nFrame with Fork\tAluminium frame, internal cable routing, 135x9mm QR\nFork\tCarbon, hidden cable routing, 100x9mm QR\n\nWheels\nWheel front\tCarbon, 30mm deep rim, 23mm width, 100x9mm QR\nWheel rear\tCarbon, 30mm deep rim, 23mm width, 135x9mm QR\nSkewer front\t100x9mm QR\nSkewer rear\t135x9mm QR\nTire\tContinental Grand Prix 5000, 700x25mm, folding bead\nMax tire size\t700x28mm without fenders\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 11 speed\nRear derailleur\tShimano Ultegra R8000, 11 speed\n*Crank\tSize: S, M\nShimano Ultegra R8000, 50/34T, 170mm length\nSize: L, XL\nShimano Ultegra R8000, 50/34T, 175mm length\nBottom bracket\tShimano BB-RS500-PB, PressFit\nCassette\tShimano Ultegra R8000, 11-30T, 11 speed\nChain\tShimano Ultegra HG701, 11 speed\nPedal\tNot included\nMax chainring size\t50/34T\n\nComponents\nSaddle\tBontrager Montrose Comp, steel rails, 138mm width\nSeatpost\tBontrager Comp, 6061 alloy, 27.2mm, 8mm offset, 330mm length\n*Handlebar\tSize: S, M, L\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 400mm width\nSize: XL\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 420mm width\nGrips\tBontrager Supertack Perf tape\n*Stem\tSize: S, M, L\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 90mm length\nSize: XL\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 100mm length\nBrake\tShimano Ultegra R8070 hydraulic disc, flat mount\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.15 kg / 17.97 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" | 74 - 78 cm 29\" - 31\" |\n| M | 170 - 178 cm 5'7\" - 5'10\" | 77 - 82 cm 30\" - 32\" |\n| L | 178 - 186 cm 5'10\" - 6'1\" | 82 - 86 cm 32\" - 34\" |\n| XL | 186 - 196 cm 6'1\" - 6'5\" | 87 - 92 cm 34\" - 36\" |\n\n\n## Geometry\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.0 | 52.0 | 54.0 | 56.0 |\n| B — Seat tube angle | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 13.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 71.0° | 72.0° | 72.0° | 72.5° |\n| E — Effective top tube | 53.7 | 55.0 | 56.5 | 58.0 |\n| F — Bottom bracket height | 27.5 | 27.5 | 27.5 | 27.5 |\n| G — Bottom bracket drop | 7.3 | 7.3 | 7.3 | 7.3 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 5.8 |\n| K — Wheelbase | 98.2 | 99.1 | 100.1 | 101.0 |\n| L — Standover | 75.2 | 78.2 | 81.1 | 84.1 |\n| M — Frame reach | 37.5 | 38.3 | 39.1 | 39.9 |\n| N — Frame stack | 53.3 | 55.4 | 57.4 | 59.5 |", - "price": 1799.99, - "tags": [ - "bicycle", - "touring bike" - ] - }, - { - "name": "Velocity V9", - "shortDescription": "Velocity V9 is a high-performance hybrid bike that combines speed and comfort for riders who demand the best of both worlds. The lightweight aluminum frame, along with the carbon fork and seat post, provide optimal stiffness and absorption to tackle any terrain. A 2x Shimano Deore drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires make it a versatile ride for commuters, fitness riders, and weekend adventurers alike.", - "description": "## Overview\nIt's right for you if...\nYou want a fast, versatile bike that can handle anything from commuting to weekend adventures. You value comfort as much as speed and performance. You want a reliable and durable bike that will last for years to come.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork and seat post, a 2x Shimano Deore drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. The Velocity V9 is designed for riders who demand both performance and comfort in one package.\n\nThe final word\nThe Velocity V9 is the perfect bike for riders who want speed and performance without sacrificing comfort. The lightweight aluminum frame and carbon components provide optimal stiffness and absorption, while the 2x Shimano Deore drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're commuting, hitting the trails, or training for your next race, the Velocity V9 has everything you need to achieve your goals.\n\n## Features\n\n2x drivetrain\nA 2x drivetrain means more versatility and a wider range of gearing options. Whether you're climbing hills or sprinting on the flats, the Velocity V9 has the perfect gear for any situation.\n\nCarbon components\nThe Velocity V9 features a carbon fork and seat post to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unparalleled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\n## Specifications\nFrameset\nFrame with Fork\tAluminum frame with carbon fork and seat post, internal cable routing, fender mounts, 135x5mm ThruSkew\nFork\tCarbon fork, hidden fender mounts, flat mount disc, 5x100mm thru-skew\n\nWheels\nWheel front\tDouble wall aluminum rims, 700c, quick release hub\nWheel rear\tDouble wall aluminum rims, 700c, quick release hub\nTire\tKenda Kwick Tendril, puncture resistant, reflective sidewall, 700x32c\nMax tire size\t700x35c without fenders, 700x32c with fenders\n\nDrivetrain\nShifter\tShimano Deore, 10 speed\nFront derailleur\tShimano Deore\nRear derailleur\tShimano Deore\nCrank\tShimano Deore, 46-30T, 170mm (S/M), 175mm (L/XL)\nBottom bracket\tShimano BB52, 68mm, threaded\nCassette\tShimano Deore, 11-36T, 10 speed\nChain\tShimano HG54, 10 speed\nPedal\tWellgo alloy platform\n\nComponents\nSaddle\tVelo VL-2158, steel rails\nSeatpost\tCarbon seat post, 27.2mm\nHandlebar\tAluminum, 31.8mm clamp, 15mm rise, 680mm width\nGrips\tVelo ergonomic grips\nStem\tAluminum, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, MT200 lever, MT200 caliper\nBrake rotor\tShimano RT56, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 11.5 kg / 25.35 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 44.0 | 48.0 | 52.0 | 56.0 |\n| B — Seat tube angle | 74.5° | 74.0° | 73.5° | 73.0° |\n| C — Head tube length | 14.5 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 71.0° | 71.0° | 71.5° | 71.5° |\n| E — Effective top tube | 56.5 | 57.5 | 58.5 | 59.5 |\n| F — Bottom bracket height | 27.0 | 27.0 | 27.0 | 27.0 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 7.0 | 7.0 | 6.6 | 6.6 |\n| K — Wheelbase | 105.4 | 106.3 | 107.2 | 108.2 |\n| L — Standover | 73.2 | 77.1 | 81.2 | 85.1 |\n| M — Frame reach | 39.0 | 39.8 | 40.4 | 41.3 |\n| N — Frame stack | 57.0 | 58.5 | 60.0 | 61.5 |", - "price": 2199.99, - "tags": [ - "bicycle", - "electric bike", - "mountain bike" - ] - }, - { - "name": "Aero Pro X", - "shortDescription": "Aero Pro X is a high-end racing bike designed for serious cyclists who demand speed, agility, and superior performance. The lightweight carbon frame and fork, combined with the aerodynamic design, provide optimal stiffness and efficiency to maximize your speed. The bike features a 2x Shimano Ultegra drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires. Whether you're competing in a triathlon or climbing steep hills, Aero Pro X delivers exceptional performance and precision handling.", - "description": "## Overview\nIt's right for you if...\nYou are a competitive cyclist looking for a bike that is designed for racing. You want a bike that delivers exceptional speed, agility, and precision handling. You demand superior performance and reliability from your equipment.\n\nThe tech you get\nA lightweight carbon frame with an aerodynamic design, a carbon fork with hidden fender mounts, a 2x Shimano Ultegra drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. Aero Pro X is designed for serious cyclists who demand nothing but the best.\n\nThe final word\nAero Pro X is the ultimate racing bike for serious cyclists. The lightweight carbon frame and aerodynamic design deliver maximum speed and efficiency, while the 2x Shimano Ultegra drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're competing in a triathlon or a criterium race, Aero Pro X delivers the performance you need to win.\n\n## Features\n\nAerodynamic design\nThe Aero Pro X features an aerodynamic design that reduces drag and maximizes efficiency. The bike is optimized for speed and agility, so you can ride faster and farther with less effort.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unrivaled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\nCarbon components\nThe Aero Pro X features a carbon fork with hidden fender mounts to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\n## Specifications\nFrameset\nFrame with Fork\tCarbon frame with an aerodynamic design, internal cable routing, 3s chain keeper, 142x12mm thru-axle\nFork\tCarbon fork with hidden fender mounts, flat mount disc, 100x12mm thru-axle\n\nWheels\nWheel front\tDouble wall carbon rims, 700c, thru-axle hub\nWheel rear\tDouble wall carbon rims, 700c, thru-axle hub\nTire\tContinental Grand Prix 5000, folding bead, 700x25c\nMax tire size\t700x28c without fenders, 700x25c with fenders\n\nDrivetrain\nShifter\tShimano Ultegra, 11 speed\nFront derailleur\tShimano Ultegra\nRear derailleur\tShimano Ultegra\nCrank\tShimano Ultegra, 52-36T, 170mm (S), 172.5mm (M), 175mm (L/XL)\nBottom bracket\tShimano BB72, 68mm, PressFit\nCassette\tShimano Ultegra, 11-30T, 11 speed\nChain\tShimano HG701, 11 speed\nPedal\tNot included\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tCarbon seat post, 27.2mm, 20mm offset\nHandlebar\tBontrager XXX Aero, carbon, 31.8mm clamp, 75mm reach, 125mm drop\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Pro, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, Ultegra lever, Ultegra caliper\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.36 kg / 18.42 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.6 | 52.4 | 54.3 | 56.2 |\n| B — Seat tube angle | 75.5° | 74.5° | 73.5° | 72.5° |\n| C — Head tube length | 12.0 | 14.0 | 16.0 | 18.0 |\n| D — Head angle | 72.5° | 73.0° | 73.5° | 74.0° |\n| E — Effective top tube | 53.8 | 55.4 | 57.0 | 58.6 |\n| F — Bottom bracket height | 26.5 | 26.5 | 26.5 | 26.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 6.0 |\n| K — Wheelbase | 97.1 | 98.7 | 100.2 | 101.8 |\n| L — Standover | 73.8 | 76.2 | 78.5 | 80.8 |\n| M — Frame reach | 38.8 | 39.5 | 40.2 | 40.9 |\n| N — Frame stack | 52.8 | 54.7 | 56.6 | 58.5 |", - "price": 1599.99, - "tags": [ - "bicycle", - "road bike" - ] - }, - { - "name": "Voltex+ Ultra Lowstep", - "shortDescription": "Voltex+ Ultra Lowstep is a high-performance electric hybrid bike designed for riders who seek speed, comfort, and reliability during their everyday rides. Equipped with a powerful and efficient Voltex Drive Pro motor and a fully-integrated 600Wh battery, this e-bike allows you to cover longer distances on a single charge. The Voltex+ Ultra Lowstep comes with premium components that prioritize comfort and safety, such as a suspension seatpost, wide and stable tires, and integrated lights.", - "description": "## Overview\n\nIt's right for you if...\nYou want an e-bike that provides a boost for faster rides and effortless usage. Durability is crucial, and you need a bike with one of the most powerful and efficient motors.\n\nThe tech you get\nA lightweight Delta Carbon Fiber frame with an ultra-lowstep design, a Voltex Drive Pro (350W, 75Nm) motor capable of maintaining speeds up to 30 mph, an extended range 600Wh battery integrated into the frame, and a Voltex Control Panel. Additionally, it features a 12-speed Shimano drivetrain, hydraulic disc brakes for optimal all-weather stopping power, a suspension seatpost, wide puncture-resistant tires for added stability, ergonomic grips, a kickstand, lights, and a cargo rack.\n\nThe final word\nThis bike offers enhanced enjoyment and ease of use on long commutes, leisure rides, and adventures. With its extended-range battery, powerful Voltex motor, user-friendly controller, and a seatpost that smooths out road vibrations, it guarantees an exceptional riding experience.\n\n## Features\n\nUltra-fast assistance\n\nExperience speeds up to 30 mph with the cutting-edge Voltex Drive Pro motor, allowing you to breeze through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\n- Frame: Delta Carbon Fiber, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Voltex Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: Voltex Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: Voltex E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore XT M8100, 12-speed\n- Rear derailleur: Shimano Deore XT M8100, long cage\n- Crank: Voltex alloy, 170mm length\n- Chainring: FSA, 44T, aluminum with guard\n- Cassette: Shimano Deore XT M8100, 10-51, 12-speed\n- Chain: KMC E12 Turbo\n- Pedal: Voltex Urban pedals\n\nComponents\n- Saddle: Voltex Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar: Voltex alloy, 31.8mm, comfort sweep, 620mm width (XS, S, M), 660mm width (L)\n- Grips: Voltex Satellite Elite, alloy lock-on\n- Stem: Voltex alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length (XS, S), 105mm length (M, L)\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT520 hydraulic disc\n- Brake rotor: Shimano RT56, 6-bolt, 180mm (XS, S, M, L), 160mm (XS, S, M, L)\n\nAccessories\n- Battery: Voltex PowerTube 600Wh\n- Charger: Voltex compact 2A, 100-240V\n- Computer: Voltex Control Panel\n- Motor: Voltex Drive Pro, 75Nm, 30mph\n- Light: Voltex Solo for e-bike, taillight (XS, S, M, L), Voltex MR8, 180 lumen, 60 lux, LED, headlight (XS, S, M, L)\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: Voltex-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender: Voltex wide (XS, S, M, L), Voltex plastic (XS, S, M, L)\n\nWeight\n- Weight: M - 20.50 kg / 45.19 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 330 pounds (150 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 38.0 | 43.0 | 48.0 | 53.0 |\n| B — Seat tube angle | 70.5° | 70.5° | 70.5° | 70.5° |\n| C — Head tube length | 15.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 69.2° | 69.2° | 69.2° | 69.2° |\n| E — Effective top tube | 57.2 | 57.7 | 58.8 | 60.0 |\n| F — Bottom bracket height | 30.3 | 30.3 | 30.3 | 30.3 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.5 | 48.5 | 48.5 | 48.5 |\n| I — Offset | 5.0 | 5.0 | 5.0 | 5.0 |\n| J — Trail | 9.0 | 9.0 | 9.0 | 9.0 |\n| K — Wheelbase | 111.8 | 112.3 | 113.6 | 114.8 |\n| L — Standover | 42.3 | 42.3 | 42.3 | 42.3 |\n| M — Frame reach | 36.0 | 38.0 | 38.0 | 38.0 |\n| N — Frame stack | 62.0 | 62.0 | 63.9 | 65.8 |\n| Stem length | 8.0 | 8.5 | 8.5 | 10.5 |\n\nPlease note that the specifications and features listed above are subject to change and may vary based on different models and versions of the Voltex+ Ultra Lowstep bike.", - "price": 2999.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "SwiftRide Hybrid", - "shortDescription": "SwiftRide Hybrid is a versatile and efficient bike designed for riders who want a smooth and enjoyable ride on various terrains. It incorporates advanced technology and high-quality components to provide a comfortable and reliable cycling experience.", - "description": "## Overview\n\nIt's right for you if...\nYou are looking for a bike that combines the benefits of an electric bike with the versatility of a hybrid. You value durability, speed, and ease of use.\n\nThe tech you get\nThe SwiftRide Hybrid features a lightweight and durable aluminum frame, making it easy to handle and maneuver. It is equipped with a powerful electric motor that offers a speedy assist, helping you reach speeds of up to 25 mph. The bike comes with a removable and fully-integrated 500Wh battery, providing a long-range capacity for extended rides. It also includes a 10-speed Shimano drivetrain, hydraulic disc brakes for precise stopping power, wide puncture-resistant tires for stability, and integrated lights for enhanced visibility.\n\nThe final word\nThe SwiftRide Hybrid is designed for riders who want a bike that can handle daily commutes, recreational rides, and adventures. With its efficient motor, intuitive controls, and comfortable features, it offers an enjoyable and hassle-free riding experience.\n\n## Features\n\nEfficient electric assist\nExperience the thrill of effortless riding with the powerful electric motor that provides a speedy assist, making your everyday rides faster and more enjoyable.\n\n## Specs\n\nFrameset\n- Frame: Lightweight Aluminum, Removable Integrated Battery (RIB), rack & fender mounts, internal routing, 135x5mm QR\n- Fork: SwiftRide Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: SwiftRide Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: SwiftRide E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: SwiftRide City pedals\n\nComponents\n- Saddle: SwiftRide Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - SwiftRide alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - SwiftRide alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: SwiftRide Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 85mm length\n - Size: M, L - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: SwiftRide PowerTube 500Wh\n- Charger: SwiftRide compact 2A, 100-240V\n- Computer: SwiftRide Purion\n- Motor: SwiftRide Performance Line Sport, 65Nm, 25mph\n- Light:\n - Size: XS, S, M, L - SwiftRide SOLO for e-bike, taillight\n - Size: XS, S, M, L - SwiftRide MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: SwiftRide-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SwiftRide wide\n - Size: XS, S, M, L - SwiftRide plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm (4'10\" - 5'1\") | 69 - 73 cm (27\" - 29\") |\n| S | 155 - 165 cm (5'1\" - 5'5\") | 72 - 78 cm (28\" - 31\") |\n| M | 165 - 175 cm (5'5\" - 5'9\") | 77 - 83 cm (30\" - 33\") |\n| L | 175 - 186 cm (5'9\" - 6'1\") | 82 - 88 cm (32\" - 35\") |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 3999.99, - "tags": [ - "bicycle", - "mountain bike", - "professional" - ] - }, - { - "name": "RoadRunner E-Speed Lowstep", - "shortDescription": "RoadRunner E-Speed Lowstep is a high-performance electric hybrid designed for riders seeking speed and excitement on their daily rides. It is equipped with a powerful and reliable ThunderBolt drive unit that offers exceptional acceleration. The bike features a fully-integrated 500Wh battery, allowing riders to cover longer distances on a single charge. With its comfortable and safe components, including a suspension seatpost, wide and stable tires, and integrated lights, the RoadRunner E-Speed Lowstep ensures a smooth and enjoyable ride.", - "description": "## Overview\n\nIt's right for you if...\nYou're looking for an e-bike that provides an extra boost to reach your destination quickly and effortlessly. You prioritize durability and want a bike with one of the fastest motors available.\n\nThe tech you get\nA lightweight and sturdy ThunderBolt aluminum frame with a lowstep geometry. The bike is equipped with a ThunderBolt Performance Sport (250W, 65Nm) drive unit capable of reaching speeds up to 28 mph. It features a long-range 500Wh battery fully integrated into the frame and a ThunderBolt controller. Additionally, the bike has a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe RoadRunner E-Speed Lowstep is designed to provide enjoyment and ease of use on longer commutes, recreational rides, and adventurous journeys. Its long-range battery, fast ThunderBolt motor, intuitive controller, and road-smoothing suspension seatpost make it the perfect choice for riders seeking both comfort and speed.\n\n## Features\n\nSuper speedy assist\n\nThe ThunderBolt Performance Sport drive unit allows you to accelerate up to 28mph, making errands, commutes, and joyrides a breeze.\n\n## Specs\n\nFrameset\n- Frame: ThunderBolt Smooth Aluminum, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: RoadRunner Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: ThunderBolt DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: ThunderBolt DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: ThunderBolt Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: ThunderBolt E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: RoadRunner City pedals\n\nComponents\n- Saddle: RoadRunner Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - RoadRunner alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - RoadRunner alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: RoadRunner Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: ThunderBolt PowerTube 500Wh\n- Charger: ThunderBolt compact 2A, 100-240V\n- Computer: ThunderBolt Purion\n- Motor: ThunderBolt Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - ThunderBolt SOLO for e-bike, taillight\n - Size: XS, S, M, L - ThunderBolt MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - RoadRunner wide\n - Size: XS, S, M, L - RoadRunner plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 4999.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "Hyperdrive Turbo X1", - "shortDescription": "Hyperdrive Turbo X1 is a high-performance electric bike designed for riders seeking an exhilarating experience on their daily rides. It features a powerful and efficient Hyperdrive Sport drive unit and a sleek, integrated 500Wh battery for extended range. This e-bike is equipped with top-of-the-line components prioritizing comfort and safety, including a suspension seatpost, wide and stable tires, and integrated lights.", - "description": "## Overview\n\nIt's right for you if...\nYou crave the thrill of an e-bike that can accelerate rapidly, reaching high speeds effortlessly. You value durability and are looking for a bike that is equipped with one of the fastest motors available.\n\nThe tech you get\nA lightweight Hyper Alloy frame with a lowstep geometry, a Hyperdrive Sport (300W, 70Nm) drive unit capable of maintaining speeds up to 30 mph, a long-range 500Wh battery seamlessly integrated into the frame, and an intuitive Hyper Control controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for enhanced stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThis bike is designed for riders seeking enjoyment and convenience on longer commutes, recreational rides, and thrilling adventures. With its long-range battery, high-speed motor, user-friendly controller, and smooth-riding suspension seatpost, the Hyperdrive Turbo X1 guarantees an exceptional e-biking experience.\n\n## Features\n\nHyperboost Acceleration\nExperience adrenaline-inducing rides with the powerful Hyperdrive Sport drive unit that enables quick acceleration and effortless cruising through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\nFrame\tHyper Alloy, Removable Integrated Battery (RIB), seamless welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\nFork\tHyper Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\nMax compatible fork travel\t50mm\n\nWheels\nHub front\tFormula DC-20, alloy, 6-bolt, 5x100mm QR\nSkewer front\t132x5mm QR, ThruSkew\nHub rear\tFormula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\nSkewer rear\t153x5mm bolt-on\nRim\tHyper Connection, double-wall, 32-hole, 20 mm width, Schrader valve\nTire\tHyper E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\nMax tire size\t700x50mm with or without fenders\n\nDrivetrain\nShifter\tShimano Deore M4100, 10 speed\nRear derailleur\tShimano Deore M5120, long cage\nCrank\tProWheel alloy, 170mm length\nChainring\tFSA, 42T, steel w/guard\nCassette\tShimano Deore M4100, 11-42, 10 speed\nChain\tKMC E10\nPedal\tHyper City pedals\n\nComponents\nSaddle\tHyper Boulevard\nSeatpost\tAlloy, suspension, 31.6mm, 300mm length\n*Handlebar\tSize: XS, S, M\nHyper alloy, 31.8mm, comfort sweep, 620mm width\nSize: L\nHyper alloy, 31.8mm, comfort sweep, 660mm width\nGrips\tHyper Satellite Elite, alloy lock-on\n*Stem\tSize: XS, S\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\nSize: M, L\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\nHeadset\tVP sealed cartridge, 1-1/8'', threaded\nBrake\tShimano MT200 hydraulic disc\n*Brake rotor\tSize: XS, S, M, L\nShimano RT26, 6-bolt,180mm\nSize: XS, S, M, L\nShimano RT26, 6-bolt,160mm\n\nAccessories\nBattery\tHyper PowerTube 500Wh\nCharger\tHyper compact 2A, 100-240V\nComputer\tHyper Control\nMotor\tHyperdrive Sport, 70Nm, 30mph\n*Light\tSize: XS, S, M, L\nSpanninga SOLO for e-bike, taillight\nSize: XS, S, M, L\nHerrmans MR8, 180 lumen, 60 lux, LED, headlight\nKickstand\tAdjustable length rear mount alloy kickstand\nCargo rack\tMIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n*Fender\tSize: XS, S, M, L\nSKS wide\nSize: XS, S, M, L\nSKS plastic\n\nWeight\nWeight\tM - 22.30 kg / 49.17 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 1999.99, - "tags": [ - "bicycle", - "city bike", - "professional" - ] - }, - { - "name": "Horizon+ Evo Lowstep", - "shortDescription": "The Horizon+ Evo Lowstep is a versatile electric hybrid bike designed for riders seeking a thrilling and efficient riding experience on a variety of terrains. With its powerful Bosch Performance Line Sport drive unit and integrated 500Wh battery, this e-bike enables riders to cover long distances with ease. Equipped with features prioritizing comfort and safety, such as a suspension seatpost, stable tires, and integrated lights, the Horizon+ Evo Lowstep is a reliable companion for everyday rides.", - "description": "## Overview\n\nIt's right for you if...\nYou desire the convenience and speed of an e-bike to enhance your riding, and you want an intuitive and durable bicycle. You prioritize having one of the fastest motors developed by Bosch.\n\nThe tech you get\nA lightweight Alpha Smooth Aluminum frame with a lowstep geometry, a Bosch Performance Line Sport (250W, 65Nm) drive unit capable of sustaining speeds up to 28 mph, a fully encased 500Wh battery integrated into the frame, and a Bosch Purion controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for improved stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe Horizon+ Evo Lowstep offers an enjoyable and user-friendly riding experience for longer commutes, recreational rides, and adventures. It boasts an extended range battery, a high-performance Bosch motor, an intuitive controller, and a suspension seatpost for a smooth ride on various road surfaces.\n\n## Features\n\nSuper speedy assist\nExperience effortless cruising through errands, commutes, and joyrides with the new Bosch Performance Sport drive unit, allowing acceleration of up to 28 mph.\n\n## Specs\n\nFrameset\n- Frame: Alpha Platinum Aluminum, Removable Integrated Battery (RIB), smooth welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Horizon Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Front Hub: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Front Skewer: 132x5mm QR, ThruSkew\n- Rear Hub: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Rear Skewer: 153x5mm bolt-on\n- Rim: Bontrager Connection, double-wall, 32-hole, 20mm width, Schrader valve\n- Tire: Bontrager E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10-speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10-speed\n- Chain: KMC E10\n- Pedal: Bontrager City pedals\n\nComponents\n- Saddle: Bontrager Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - Bontrager alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - Bontrager alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: Bontrager Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8\", threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: Bosch PowerTube 500Wh\n- Charger: Bosch compact 2A, 100-240V\n- Computer: Bosch Purion\n- Motor: Bosch Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - Spanninga SOLO for e-bike, taillight\n - Size: XS, S, M, L - Herrmans MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SKS wide\n - Size: XS, S, M, L - SKS plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 4499.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "FastRider X1", - "shortDescription": "FastRider X1 is a high-performance e-bike designed for riders seeking speed and long-distance capabilities. Equipped with a powerful motor and a high-capacity battery, the FastRider X1 is perfect for daily commuters and e-bike enthusiasts. It boasts a sleek and functional design, making it a great alternative to car transportation. The bike also features a smartphone controller for easy navigation and entertainment options.", - "description": "## Overview\nIt's right for you if...\nYou're looking for an e-bike that offers both speed and endurance. The FastRider X1 comes with a high-performance motor and a long-lasting battery, making it ideal for long-distance rides.\n\nThe tech you get\nThe FastRider X1 features a state-of-the-art motor and a spacious battery, ensuring a fast and efficient ride.\n\nThe final word\nWith the powerful motor and long-range battery, the FastRider X1 allows you to cover more distance at higher speeds.\n\n## Features\nConnect Your Ride with the FastRider App\nDownload the FastRider app and transform your smartphone into an on-board computer. Easily dock and charge your phone with the smartphone controller, and use the thumb pad on your handlebar to make calls, listen to music, get turn-by-turn directions, and more. The app also allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nGoodbye, Car. Hello, Extended Range!\nWith the option to add the Range Boost feature, you can attach a second long-range battery to your FastRider X1, doubling the distance and time between charges. This enhancement allows you to ride longer, commute farther, and take on more adventurous routes.\n\nWhat is the range?\nTo estimate the distance you can travel on a single charge, use our range calculator tool. It automatically fills in the variables for this specific bike model and assumes an average rider, but you can adjust the settings to get the most accurate estimate for your needs.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: FastRider rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: FastRider sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: FastRider Switch thru axle, removable lever\n- Rear Hub: FastRider alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: FastRider MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: FastRider E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - FastRider alloy, 170mm length / Size: L, XL - FastRider alloy, 175mm length\n- Chainring: FastRider 46T narrow/wide alloy, w/alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10 / Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - FastRider City pedals / Size: M, L, XL - Wellgo C157, boron axle, plastic body / Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: FastRider Commuter Comp\n- Seatpost: FastRider Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - FastRider alloy, 31.8mm, 15mm rise, 600mm width / Size: L, XL - FastRider alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: FastRider Satellite Elite, alloy lock-on\n- Stem: Size: M - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length / Size: L - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length / Size: XL - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom / Size: M, L, XL - FSA Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: FastRider PowerTube 625Wh\n- Charger: FastRider standard 4A, 100-240V\n- Motor: FastRider Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - FastRider taillight, 50 lumens / Size: M, L, XL - FastRider headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy / Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: FastRider integrated rear rack, aluminum\n- Fender: FastRider custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n\nWeight limit\n- This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", - "price": 5499.99, - "tags": [ - "bicycle", - "mountain bike", - "professional" - ] - }, - { - "name": "SonicRide 8S", - "shortDescription": "SonicRide 8S is a high-performance e-bike designed for riders who crave speed and long-distance capabilities. The advanced SonicDrive motor provides powerful assistance up to 28 mph, combined with a durable and long-lasting battery for extended rides. With its sleek design and thoughtful features, the SonicRide 8S is perfect for those who prefer the freedom of riding a bike over driving a car. Plus, it comes equipped with a smartphone controller for easy navigation, music, and more.", - "description": "## Overview\nIt's right for you if...\nYou want a fast and efficient e-bike that can take you long distances. The SonicRide 8S features a hydroformed aluminum frame with a concealed 625Wh battery, a high-powered SonicDrive motor, and a Smartphone Controller. It also includes essential accessories such as lights, fenders, and a rear rack.\n\nThe tech you get\nThe SonicRide 8S is equipped with the fastest SonicDrive motor, ensuring exhilarating rides at high speeds. The long-range battery is perfect for commuters and riders looking to explore new horizons.\n\nThe final word\nWith the SonicDrive motor and long-lasting battery, you can enjoy extended rides at higher speeds.\n\n## Features\n\nConnect Your Ride with SonicRide App\nDownload the SonicRide app and transform your phone into an onboard computer. Simply attach it to the Smartphone Controller for docking and charging. Use the thumb pad on your handlebar to control calls, music, directions, and more. The Bluetooth® wireless technology allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nSay Goodbye to Limited Range with Range Boost!\nExperience the convenience of Range Boost, an additional long-range 500Wh battery that seamlessly attaches to your bike's down tube. This upgrade allows you to double your distance and time between charges, enabling longer commutes and more adventurous rides. Range Boost is compatible with select SonicRide electric bike models.\n\nWhat is the range?\nFor an accurate estimate of how far you can ride on a single charge, use SonicRide's range calculator. We have pre-filled the variables for this specific bike model and the average rider, but you can adjust them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: SonicRide rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: SonicRide sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: SonicRide Switch thru axle, removable lever\n- Rear Hub: SonicRide alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SonicRide MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: SonicRide E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - SonicRide alloy, 170mm length; Size: L, XL - SonicRide alloy, 175mm length\n- Chainring: SonicRide 46T narrow/wide alloy, with alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10; Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - SonicRide City pedals; Size: M, L, XL - Wellgo C157, boron axle, plastic body; Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: SonicRide Commuter Comp\n- Seatpost: SonicRide Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - SonicRide alloy, 31.8mm, 15mm rise, 600mm width; Size: L, XL - SonicRide alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: SonicRide Satellite Elite, alloy lock-on\n- Stem: Size: M - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length; Size: L - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length; Size: XL - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - SonicRide IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom; Size: M, L, XL - SonicRide Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: SonicRide PowerTube 625Wh\n- Charger: SonicRide standard 4A, 100-240V\n- Motor: SonicRide Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - SonicRide Lync taillight, 50 lumens; Size: M, L, XL - SonicRide Lync headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy; Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: SonicRide integrated rear rack, aluminum\n- Fender: SonicRide custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm / 5'5\" - 5'9\" | 77 - 83 cm / 30\" - 33\" |\n| L | 175 - 186 cm / 5'9\" - 6'1\" | 82 - 88 cm / 32\" - 35\" |\n| XL | 186 - 197 cm / 6'1\" - 6'6\" | 87 - 93 cm / 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |", - "price": 5999.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "SwiftVolt Pro", - "shortDescription": "SwiftVolt Pro is a high-performance e-bike designed for riders seeking a thrilling and fast riding experience. Equipped with a powerful SwiftDrive motor that provides assistance up to 30 mph and a long-lasting battery, this bike is perfect for long-distance commuting and passionate e-bike enthusiasts. The sleek and innovative design features cater specifically to individuals who prioritize cycling over driving. Additionally, the bike is seamlessly integrated with your smartphone, allowing you to use it for navigation, music, and more.", - "description": "## Overview\nThis bike is ideal for you if:\n- You desire a sleek and modern hydroformed aluminum frame that houses a 700Wh battery.\n- You want to maintain high speeds of up to 30 mph with the assistance of the SwiftDrive motor.\n- You appreciate the convenience of using your smartphone as a controller, which can be docked and charged on the handlebar.\n\n## Features\n\nConnect with SwiftSync App\nBy downloading the SwiftSync app, your smartphone becomes an interactive on-board computer. Attach it to the handlebar-mounted controller for easy access and charging. With the thumb pad, you can make calls, listen to music, receive turn-by-turn directions, and connect with fitness and health apps to track your routes and ride data via Bluetooth® wireless technology.\n\nEnhanced Range with BoostMax\nBoostMax offers the capability to attach a second 700Wh Swift battery to the downtube of your bike, effectively doubling the distance and time between charges. This allows for extended rides, longer commutes, and more significant adventures. BoostMax is compatible with select Swift electric bike models.\n\nRange Estimation\nFor an estimate of how far you can ride on a single charge, consult the Swift range calculator. The variables are automatically populated based on this bike model and the average rider, but you can modify them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: Lightweight hydroformed alloy, Removable Integrated Battery, BoostMax-compatible, internal cable routing, post-mount disc, 135x5 mm QR\n- Fork: SwiftVolt rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: Swift sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: Swift Switch thru-axle, removable lever\n- Rear Hub: Swift alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SwiftRim, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: 14g stainless steel, black\n- Tire: Swift E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: Swift alloy, 170mm length\n- Chainring: Swift 46T narrow/wide alloy, w/alloy guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: Swift City pedals\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: Swift Commuter Comp\n- Seatpost: Swift Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Swift alloy, 31.8mm, 15mm rise, 600mm width (M), 660mm width (L, XL)\n- Grips: Swift Satellite Elite, alloy lock-on\n- Stem: Swift alloy, 31.8mm, Blendr compatible, 7 degree, 70mm length (M), 90mm length (L), 100mm length (XL)\n- Headset: FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brakes: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake Rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max 180mm front & rear\n\nAccessories\n- Battery: Swift PowerTube 700Wh\n- Charger: Swift standard 4A, 100-240V\n- Motor: SwiftDrive, 90 Nm, 30 mph / 48 kph\n- Light: Swift Lync taillight, 50 lumens (M, L, XL), Swift Lync headlight, 500 lumens (M, L, XL)\n- Kickstand: Rear mount, alloy (M, L, XL), Adjustable length alloy kickstand (M, L, XL)\n- Cargo rack: SwiftVolt integrated rear rack, aluminum\n- Fender: Swift custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:-------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", - "price": 2499.99, - "tags": [ - "bicycle", - "city bike", - "professional" - ] - }, - { - "name": "AgileEon 9X", - "shortDescription": "AgileEon 9X is a high-performance e-bike designed for riders seeking speed and endurance. Equipped with a robust motor and an extended battery life, this bike is perfect for long-distance commuters and avid e-bike enthusiasts. It boasts innovative features tailored for individuals who prioritize cycling over driving. Additionally, the bike integrates seamlessly with your smartphone, allowing you to access navigation, music, and more.", - "description": "## Overview\nIt's right for you if...\nYou crave speed and want to cover long distances efficiently. The AgileEon 9X features a sleek hydroformed aluminum frame that houses a powerful motor, along with a large-capacity battery for extended rides. It comes equipped with a 10-speed drivetrain, front and rear lighting, fenders, and a rear rack.\n\nThe tech you get\nDesigned for those constantly on the move, this bike includes a state-of-the-art motor and a high-capacity battery, making it an excellent choice for lengthy commutes.\n\nThe final word\nWith the AgileEon 9X, you can push your boundaries and explore new horizons thanks to its powerful motor and long-lasting battery.\n\n## Features\n\nConnect Your Ride with RideMate App\nMake use of the RideMate app to transform your smartphone into an onboard computer. Simply attach it to the RideMate controller to dock and charge, then utilize the thumb pad on your handlebar to make calls, listen to music, receive turn-by-turn directions, and more. The bike also supports Bluetooth® wireless technology, enabling seamless connectivity with fitness and health apps for route syncing and ride data.\n\nGoodbye, car. Hello, Extended Range!\nEnhance your riding experience with the Extended Range option, which allows for the attachment of an additional high-capacity 500Wh battery to your bike's downtube. This doubles the distance and time between charges, enabling longer rides, extended commutes, and more significant adventures. The Extended Range feature is compatible with select AgileEon electric bike models.\n\nWhat is the range?\nTo determine how far you can ride on a single charge, you can utilize the range calculator provided by AgileEon. We have pre-filled the variables for this specific model and an average rider, but adjustments can be made for a more accurate estimation.\n\n## Specifications\nFrameset\nFrame: High-performance hydroformed alloy, Removable Integrated Battery, Extended Range-compatible, internal cable routing, Motor Armor, post-mount disc, 135x5 mm QR\nFork: AgileEon rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\nMax compatible fork travel: 63mm\n\nWheels\nFront Hub: AgileEon sealed bearing, 32-hole 15mm alloy thru-axle\nFront Skewer: AgileEon Switch thru-axle, removable lever\nRear Hub: AgileEon alloy, sealed bearing, 6-bolt, 135x5mm QR\nRear Skewer: 148x5mm bolt-on\nRim: AgileEon MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\nSpokes:\n- Size: M, L, XL: 14g stainless steel, black\nTire: AgileEon E6 Hard-Case Lite, reflective strip, 27.5x2.40''\nMax tire size: 27.5x2.40\"\n\nDrivetrain\nShifter: Shimano Deore M4100, 10-speed\nRear derailleur:\n- Size: M, L, XL: Shimano Deore M5120, long cage\nCrank:\n- Size: M: AgileEon alloy, 170mm length\n- Size: L, XL: AgileEon alloy, 175mm length\nChainring: AgileEon 46T narrow/wide alloy, with alloy guard\nCassette:\n- Size: M, L, XL: Shimano Deore M4100, 11-42, 10-speed\nChain:\n- Size: M, L, XL: KMC E10\nPedal:\n- Size: M, L, XL: AgileEon City pedals\nMax chainring size: 1x: 48T\n\nComponents\nSaddle: AgileEon Commuter Comp\nSeatpost: AgileEon Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\nHandlebar:\n- Size: M: AgileEon alloy, 31.8mm, 15mm rise, 600mm width\n- Size: L, XL: AgileEon alloy, 31.8mm, 15mm rise, 660mm width\nGrips: AgileEon Satellite Elite, alloy lock-on\nStem:\n- Size: M: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length\n- Size: L: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length\n- Size: XL: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\nHeadset:\n- Size: M, L, XL: AgileEon IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\nBrake rotor: Shimano RT56, 6-bolt, 180mm\nRotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\nBattery: AgileEon PowerTube 625Wh\nCharger: AgileEon standard 4A, 100-240V\nMotor: AgileEon Performance Speed, 85 Nm, 28 mph / 45 kph\nLight:\n- Size: M, L, XL: AgileEon taillight, 50 lumens\n- Size: M, L, XL: AgileEon headlight, 500 lumens\nKickstand:\n- Size: M, L, XL: Rear mount, alloy\n- Size: M, L, XL: Adjustable length alloy kickstand\nCargo rack: AgileEon integrated rear rack, aluminum\nFender: AgileEon custom aluminum\n\nWeight\nWeight: M - 25.54 kg / 56.3 lbs\nWeight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", - "price": 3499.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "Stealth R1X Pro", - "shortDescription": "Stealth R1X Pro is a high-performance carbon road bike designed for riders who crave speed and exceptional handling. With its aerodynamic tube shaping, disc brakes, and lightweight carbon wheels, the Stealth R1X Pro offers unparalleled performance for competitive road cycling.", - "description": "## Overview\nIt's right for you if...\nYou're a competitive cyclist looking for a road bike that offers superior performance in terms of speed, handling, and aerodynamics. You want a complete package that includes lightweight carbon wheels, without the need for future upgrades.\n\nThe tech you get\nThe Stealth R1X Pro features a lightweight and aerodynamic carbon frame, an advanced carbon fork, high-performance Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes. The bike also comes equipped with cutting-edge Bontrager Aeolus Elite 35 carbon wheels.\n\nThe final word\nThe Stealth R1X Pro stands out with its combination of a fast and aerodynamic frame, high-end drivetrain, and top-of-the-line carbon wheels. Whether you're racing on local roads, participating in pro stage races, or engaging in hill climbing competitions, this bike is a formidable choice that delivers an exceptional riding experience.\n\n## Features\nSleek and aerodynamic design\nThe Stealth R1X Pro's aero tube shapes maximize speed and performance, making it faster on climbs and flats alike. The bike also features a streamlined Aeolus RSL bar/stem for improved front-end aerodynamics.\n\nDesigned for all riders\nThe Stealth R1X Pro is designed to provide an outstanding fit for riders of all genders, body types, riding styles, and abilities. It comes equipped with size-specific components to ensure a comfortable and efficient riding position for competitive riders.\n\n## Specifications\nFrameset\n- Frame: Ultralight carbon frame constructed with high-performance 500 Series ADV Carbon. It features Ride Tuned performance tube optimization, a tapered head tube, internal routing, DuoTrap S compatibility, flat mount disc brake mounts, and a 142x12mm thru axle.\n- Fork: Full carbon fork (Émonda SL) with a tapered carbon steerer, internal brake routing, flat mount disc brake mounts, and a 12x100mm thru axle.\n- Frame fit: H1.5 Race geometry.\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, and a 100x12mm thru axle.\n- Rear wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, Shimano 11/12-speed freehub, and a 142x12mm thru axle.\n- Front skewer: Bontrager Switch thru axle with a removable lever.\n- Rear skewer: Bontrager Switch thru axle with a removable lever.\n- Tire: Bontrager R2 Hard-Case Lite with an aramid bead, 60 tpi, and a size of 700x25c.\n- Maximum tire size: 28mm.\n\nDrivetrain\n- Shifter:\n - Size 47, 50, 52: Shimano Ultegra R8025 with short-reach levers, 11-speed.\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed.\n- Front derailleur: Shimano Ultegra R8000, braze-on.\n- Rear derailleur: Shimano Ultegra R8000, short cage, with a maximum cog size of 30T.\n- Crank:\n - Size 47: Shimano Ultegra R8000 with 52/36 chainrings and a 165mm length.\n - Size 50, 52: Shimano Ultegra R8000 with 52/36 chainrings and a 170mm length.\n - Size 54, 56, 58: Shimano Ultegra R8000 with 52/36 chainrings and a 172.5mm length.\n - Size 60, 62: Shimano Ultegra R8000 with 52/36 chainrings and a 175mm length.\n- Bottom bracket: Praxis T47 threaded bottom bracket with internal bearings.\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed.\n- Chain: Shimano Ultegra HG701, 11-speed.\n- Maximum chainring size: 1x - 50T, 2x - 53/39.\n\nComponents\n- Saddle: Bontrager Aeolus Comp with steel rails and a width of 145mm.\n- Seatpost:\n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap with a 20mm offset and a short length.\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap with a 20mm offset and a tall length.\n- Handlebar:\n - Size 47, 50: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 38cm.\n - Size 52: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 40cm.\n - Size 54, 56, 58: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 42cm.\n - Size 60, 62: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 44cm.\n- Handlebar tape: Bontrager Supertack Perf tape.\n- Stem:\n - Size 47: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 70mm.\n - Size 50: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 80mm.\n - Size 52, 54: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 90mm.\n - Size 56: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 100mm.\n - Size 58, 60, 62: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 110mm.\n- Brake: Shimano Ultegra hydraulic disc brakes with flat mount calipers.\n- Brake rotor: Shimano RT800 with centerlock mounting, 160mm diameter.\n\nWeight\n- Weight: 8.03 kg (17.71 lbs) for the 56cm frame.\n- Weight limit: The bike has a maximum total weight limit (combined weight of the bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\nPlease refer to the table below for the corresponding Stealth R1X Pro frame sizes, recommended rider height range, and inseam measurements:\n\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:--------------:|\n| 47 | 152 - 158 cm (5'0\") | 71 - 75 cm |\n| 50 | 158 - 163 cm (5'2\") | 74 - 77 cm |\n| 52 | 163 - 168 cm (5'4\") | 76 - 79 cm |\n| 54 | 168 - 174 cm (5'6\") | 78 - 82 cm |\n| 56 | 174 - 180 cm (5'9\") | 81 - 85 cm |\n| 58 | 180 - 185 cm (5'11\") | 84 - 87 cm |\n| 60 | 185 - 190 cm (6'1\") | 86 - 90 cm |\n| 62 | 190 - 195 cm (6'3\") | 89 - 92 cm |\n\n## Geometry\nThe table below provides the geometry measurements for each frame size of the Stealth R1X Pro:\n\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|-------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", - "price": 2999.99, - "tags": [ - "bicycle", - "mountain bike", - "professional" - ] - }, - { - "name": "Avant SLR 6 Disc Pro", - "shortDescription": "Avant SLR 6 Disc Pro is a high-performance carbon road bike designed for riders who prioritize speed and handling. With its aero tube shaping, disc brakes, and lightweight carbon wheels, it offers the perfect balance of speed and control.", - "description": "## Overview\nIt's right for you if...\nYou're a rider who values exceptional performance on fast group rides and races, and you want a complete package that includes lightweight carbon wheels. The Avant SLR 6 Disc Pro is designed to provide the speed and aerodynamics you need to excel on any road.\n\nThe tech you get\nThe Avant SLR 6 Disc Pro features a lightweight 500 Series ADV Carbon frame and fork, Bontrager Aeolus Elite 35 carbon wheels, a full Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes.\n\nThe final word\nThe standout feature of this bike is the combination of its aero frame, high-performance drivetrain, and top-quality carbon wheels. Whether you're racing, tackling challenging climbs, or participating in professional stage races, the Avant SLR 6 Disc Pro is a worthy choice that will enhance your performance.\n\n## Features\nAll-new aero design\nThe Avant SLR 6 Disc Pro features innovative aero tube shapes that provide an advantage in all riding conditions, whether it's climbing or riding on flat roads. Additionally, it is equipped with a sleek new Aeolus RSL bar/stem that enhances front-end aero performance.\n\nAwesome bikes for everyone\nThe Avant SLR 6 Disc Pro is designed with the belief that every rider, regardless of gender, body type, riding style, or ability, deserves a great bike. It is equipped with size-specific components that ensure a perfect fit for competitive riders of all genders.\n\n## Specifications\nFrameset\n- Frame: Ultralight 500 Series ADV Carbon, Ride Tuned performance tube optimization, tapered head tube, internal routing, DuoTrap S compatible, flat mount disc, 142x12mm thru axle\n- Fork: Avant SL full carbon, tapered carbon steerer, internal brake routing, flat mount disc, 12x100mm thru axle\n- Frame fit: H1.5 Race\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x12mm thru axle\n- Rear wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11/12-speed freehub, 142x12mm thru axle\n- Front skewer: Bontrager Switch thru axle, removable lever\n- Rear skewer: Bontrager Switch thru axle, removable lever\n- Tire: Bontrager R2 Hard-Case Lite, aramid bead, 60 tpi, 700x25c\n- Max tire size: 28mm\n\nDrivetrain\n- Shifter: \n - Size 47, 50, 52: Shimano Ultegra R8025, short-reach lever, 11-speed\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed\n- Front derailleur: Shimano Ultegra R8000, braze-on\n- Rear derailleur: Shimano Ultegra R8000, short cage, 30T max cog\n- Crank: \n - Size 47: Shimano Ultegra R8000, 52/36, 165mm length\n - Size 50, 52: Shimano Ultegra R8000, 52/36, 170mm length\n - Size 54, 56, 58: Shimano Ultegra R8000, 52/36, 172.5mm length\n - Size 60, 62: Shimano Ultegra R8000, 52/36, 175mm length\n- Bottom bracket: Praxis, T47 threaded, internal bearing\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed\n- Chain: Shimano Ultegra HG701, 11-speed\n- Max chainring size: 1x: 50T, 2x: 53/39\n\nComponents\n- Saddle: Bontrager Aeolus Comp, steel rails, 145mm width\n- Seatpost: \n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap, 20mm offset, short length\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap, 20mm offset, tall length\n- Handlebar: \n - Size 47, 50: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 38cm width\n - Size 52: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 40cm width\n - Size 54, 56, 58: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 42cm width\n - Size 60, 62: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 44cm width\n- Handlebar tape: Bontrager Supertack Perf tape\n- Stem: \n - Size 47: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 70mm length\n - Size 50: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 80mm length\n - Size 52, 54: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 90mm length\n - Size 56: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 100mm length\n - Size 58, 60, 62: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 110mm length\n- Brake: Shimano Ultegra hydraulic disc, flat mount\n- Brake rotor: Shimano RT800, centerlock, 160mm\n\nWeight\n- Weight: 56 - 8.03 kg / 17.71 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 47 | 152 - 158 cm 5'0\" - 5'2\" | 71 - 75 cm 28\" - 30\" |\n| 50 | 158 - 163 cm 5'2\" - 5'4\" | 74 - 77 cm 29\" - 30\" |\n| 52 | 163 - 168 cm 5'4\" - 5'6\" | 76 - 79 cm 30\" - 31\" |\n| 54 | 168 - 174 cm 5'6\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| 56 | 174 - 180 cm 5'9\" - 5'11\" | 81 - 85 cm 32\" - 33\" |\n| 58 | 180 - 185 cm 5'11\" - 6'1\" | 84 - 87 cm 33\" - 34\" |\n| 60 | 185 - 190 cm 6'1\" - 6'3\" | 86 - 90 cm 34\" - 35\" |\n| 62 | 190 - 195 cm 6'3\" - 6'5\" | 89 - 92 cm 35\" - 36\" |\n\n## Geometry\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (w/short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (w/short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (w/tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (w/tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", - "price": 999.99, - "tags": [ - "bicycle", - "city bike", - "professional" - ] - } -] \ No newline at end of file + { + "name": "E-Adrenaline 8.0 EX1", + "shortDescription": "a versatile and comfortable e-MTB designed for adrenaline enthusiasts who want to explore all types of terrain. It features a powerful motor and advanced suspension to provide a smooth and responsive ride, with a variety of customizable settings to fit any rider's needs.", + "description": "## Overview\r\nIt's right for you if...\r\nYou want to push your limits on challenging trails and terrain, with the added benefit of an electric assist to help you conquer steep climbs and rough terrain. You also want a bike with a comfortable and customizable fit, loaded with high-quality components and technology.\r\n\r\nThe tech you get\r\nA lightweight, full ADV Mountain Carbon frame with a customizable geometry, including an adjustable head tube and chainstay length. A powerful and efficient motor with a 375Wh battery that can assist up to 28 mph when it's on, and provides a smooth and seamless transition when it's off. A SRAM EX1 8-speed drivetrain, a RockShox Lyrik Ultimate fork, and a RockShox Super Deluxe Ultimate rear shock.\r\n\r\nThe final word\r\nOur E-Adrenaline 8.0 EX1 is the perfect bike for adrenaline enthusiasts who want to explore all types of terrain. It's versatile, comfortable, and loaded with advanced technology to provide a smooth and responsive ride, no matter where your adventures take you.\r\n\r\n\r\n## Features\r\nVersatile and customizable\r\nThe E-Adrenaline 8.0 EX1 features a customizable geometry, including an adjustable head tube and chainstay length, so you can fine-tune your ride to fit your needs and preferences. It also features a variety of customizable settings, including suspension tuning, motor assistance levels, and more.\r\n\r\nPowerful and efficient\r\nThe bike is equipped with a powerful and efficient motor that provides a smooth and seamless transition between human power and electric assist. It can assist up to 28 mph when it's on, and provides zero drag when it's off.\r\n\r\nAdvanced suspension\r\nThe E-Adrenaline 8.0 EX1 features a RockShox Lyrik Ultimate fork and a RockShox Super Deluxe Ultimate rear shock, providing advanced suspension technology to absorb shocks and bumps on any terrain. The suspension is also customizable to fit your riding style and preferences.\r\n\r\n\r\n## Specs\r\nFrameset\r\nFrame ADV Mountain Carbon main frame & stays, adjustable head tube and chainstay length, tapered head tube, Knock Block, Control Freak internal routing, Boost148, 150mm travel\r\nFork RockShox Lyrik Ultimate, DebonAir spring, Charger 2.1 RC2 damper, remote lockout, tapered steerer, 42mm offset, Boost110, 15mm Maxle Stealth, 160mm travel\r\nShock RockShox Super Deluxe Ultimate, DebonAir spring, Thru Shaft 3-position damper, 230x57.5mm\r\n\r\nWheels\r\nWheel front Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 6-bolt, Boost110, 15mm thru axle\r\nWheel rear Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 54T Rapid Drive, 6-bolt, Shimano MicroSpline freehub, Boost148, 12mm thru axle\r\nSkewer rear Bontrager Switch thru axle, removable lever\r\nTire Bontrager XR5 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.50''\r\nTire part Bontrager TLR sealant, 6oz\r\n\r\nDrivetrain\r\nShifter SRAM EX1, 8 speed\r\nRear derailleur SRAM EX1, 8 speed\r\nCrank Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nChainring SRAM EX1, 18T, steel\r\nCassette SRAM EX1, 11-48, 8 speed\r\nChain SRAM EX1, 8 speed\r\n\r\nComponents\r\nSaddle Bontrager Arvada, hollow chromoly rails, 138mm width\r\nSeatpost Bontrager Line Elite Dropper, internal routing, 31.6mm\r\nHandlebar Bontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\r\nGrips Bontrager XR Trail Elite, alloy lock-on\r\nStem Bontrager Line Pro, 35mm, Knock Block, Blendr compatible, 0 degree, 50mm length\r\nHeadset Knock Block Integrated, 62-degree radius, cartridge bearing, 1-1\/8'' top, 1.5'' bottom\r\nBrake SRAM G2 RSC hydraulic disc, carbon levers\r\nBrake rotor SRAM Centerline, centerlock, round edge, 200mm\r\n\r\nAccessories\r\nE-bike system Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nBattery Bosch PowerTube 625, 625Wh\r\nCharger Bosch 4A standard charger\r\nController Bosch Kiox with Anti-theft solution, Bluetooth connectivity, 1.9'' display\r\nTool Bontrager Switch thru axle, removable lever\r\n\r\nWeight\r\nWeight M - 20.25 kg \/ 44.6 lbs (with TLR sealant, no tubes)\r\nWeight limit This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\r\n\r\n## Sizing & fit\r\n\r\n| Size | Rider Height | Inseam |\r\n|:----:|:------------------------:|:--------------------:|\r\n| S | 155 - 170 cm 5'1\" - 5'7\" | 73 - 80 cm 29\" - 31.5\" |\r\n| M | 163 - 178 cm 5'4\" - 5'10\" | 77 - 83 cm 30.5\" - 32.5\" |\r\n| L | 176 - 191 cm 5'9\" - 6'3\" | 83 - 89 cm 32.5\" - 35\" |\r\n| XL | 188 - 198 cm 6'2\" - 6'6\" | 88 - 93 cm 34.5\" - 36.5\" |\r\n\r\n\r\n## Geometry\r\n\r\nAll measurements provided in cm unless otherwise noted.\r\nSizing table\r\n| Frame size letter | S | M | L | XL |\r\n|---------------------------|-------|-------|-------|-------|\r\n| Actual frame size | 15.8 | 17.8 | 19.8 | 21.8 |\r\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\r\n| A \u2014 Seat tube | 40.0 | 42.5 | 47.5 | 51.0 |\r\n| B \u2014 Seat tube angle | 72.5\u00B0 | 72.8\u00B0 | 73.0\u00B0 | 73.0\u00B0 |\r\n| C \u2014 Head tube length | 9.5 | 10.5 | 11.0 | 11.5 |\r\n| D \u2014 Head angle | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 |\r\n| E \u2014 Effective top tube | 59.0 | 62.0 | 65.0 | 68.0 |\r\n| F \u2014 Bottom bracket height | 32.5 | 32.5 | 32.5 | 32.5 |\r\n| G \u2014 Bottom bracket drop | 5.5 | 5.5 | 5.5 | 5.5 |\r\n| H \u2014 Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\r\n| I \u2014 Offset | 4.5 | 4.5 | 4.5 | 4.5 |\r\n| J \u2014 Trail | 11.0 | 11.0 | 11.0 | 11.0 |\r\n| K \u2014 Wheelbase | 113.0 | 117.0 | 120.0 | 123.0 |\r\n| L \u2014 Standover | 77.0 | 77.0 | 77.0 | 77.0 |\r\n| M \u2014 Frame reach | 41.0 | 44.5 | 47.5 | 50.0 |\r\n| N \u2014 Frame stack | 61.0 | 62.0 | 62.5 | 63.0 |", + "price": 1499.99, + "tags": [ + "bicycle" + ] + }, + { + "name": "Enduro X Pro", + "shortDescription": "The Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame and top-of-the-line components, this bike is ready to tackle any trail, from technical downhill descents to grueling uphill climbs.", + "text": "## Overview\nIt's right for you if...\nYou're an experienced mountain biker who wants a high-performance bike that can handle any terrain. You want a bike with the best components available, including a full carbon frame, suspension system, and hydraulic disc brakes.\n\nThe tech you get\nOur top-of-the-line full carbon frame with aggressive geometry and a slack head angle for maximum control. It's equipped with a Fox Factory suspension system with 170mm of travel in the front and 160mm in the rear, a Shimano XTR 12-speed drivetrain, and hydraulic disc brakes for maximum stopping power. The bike also features a dropper seatpost for easy adjustments on the fly.\n\nThe final word\nThe Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame, top-of-the-line components, and aggressive geometry, this bike is ready to take on any trail. Whether you're a seasoned pro or just starting out, the Enduro X Pro will help you take your riding to the next level.\n\n## Features\nFull carbon frame\nAggressive geometry with a slack head angle\nFox Factory suspension system with 170mm of travel in the front and 160mm in the rear\nShimano XTR 12-speed drivetrain\nHydraulic disc brakes for maximum stopping power\nDropper seatpost for easy adjustments on the fly\n\n## Specifications\nFrameset\nFrame\tFull carbon frame\nFork\tFox Factory suspension system with 170mm of travel\nRear suspension\tFox Factory suspension system with 160mm of travel\n\nWheels\nWheel size\t27.5\" or 29\"\nTires\tTubeless-ready Maxxis tires\n\nDrivetrain\nShifters\tShimano XTR 12-speed\nFront derailleur\tN/A\nRear derailleur\tShimano XTR\nCrankset\tShimano XTR\nCassette\tShimano XTR 12-speed\nChain\tShimano XTR\n\nComponents\nBrakes\tHydraulic disc brakes\nHandlebar\tAlloy handlebar\nStem\tAlloy stem\nSeatpost\tDropper seatpost\n\nAccessories\nPedals\tNot included\n\nWeight\nWeight\tApproximately 27-29 lbs\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 5'4\" - 5'8\" (162-172cm) |\n| M | 5'8\" - 5'11\" (172-180cm) |\n| L | 5'11\" - 6'3\" (180-191cm) |\n| XL | 6'3\" - 6'6\" (191-198cm) |\n\n## Geometry\n| Size | S | M | L | XL |\n|:----:|:---------------:|:---------------:|:-----------------:|:---------------:|\n| A - Seat tube length | 390mm | 425mm | 460mm | 495mm |\n| B - Effective top tube length | 585mm | 610mm | 635mm | 660mm |\n| C - Head tube angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| D - Seat tube angle | 76° | 76° | 76° | 76° |\n| E - Chainstay length | 435mm | 435mm | 435mm | 435mm |\n| F - Head tube length | 100mm | 110mm | 120mm | 130mm |\n| G - BB drop | 20mm | 20mm | 20mm | 20mm |\n| H - Wheelbase | 1155mm | 1180mm | 1205mm | 1230mm |\n| I - Standover height | 780mm | 800mm | 820mm | 840mm |\n| J - Reach | 425mm | 450mm | 475mm | 500mm |\n| K - Stack | 610mm | 620mm | 630mm | 640mm |", + "price": 599.99, + "tags": [ + "bicycle" + ] + }, + { + "name": "Blaze X1", + "shortDescription": "Blaze X1 is a high-performance road bike that offers superior speed and agility, making it perfect for competitive racing or fast-paced group rides. The bike features a lightweight carbon frame, aerodynamic tube shapes, a 12-speed Shimano Ultegra drivetrain, and hydraulic disc brakes for precise stopping power. With its sleek design and cutting-edge technology, Blaze X1 is a bike that is built to perform and dominate on any road.", + "description": "## Overview\nIt's right for you if...\nYou're a competitive road cyclist or an enthusiast who enjoys fast-paced group rides. You want a bike that is lightweight, agile, and delivers exceptional speed.\n\nThe tech you get\nBlaze X1 features a lightweight carbon frame with a tapered head tube and aerodynamic tube shapes for maximum speed and efficiency. The bike is equipped with a 12-speed Shimano Ultegra drivetrain for smooth and precise shifting, Shimano hydraulic disc brakes for powerful and reliable stopping power, and Bontrager Aeolus Elite 35 carbon wheels for increased speed and agility.\n\nThe final word\nBlaze X1 is a high-performance road bike that is designed to deliver exceptional speed and agility. With its cutting-edge technology and top-of-the-line components, it's a bike that is built to perform and dominate on any road.\n\n## Features\nSpeed and efficiency\nBlaze X1's lightweight carbon frame and aerodynamic tube shapes offer maximum speed and efficiency, allowing you to ride faster and farther with ease.\n\nPrecision stopping power\nShimano hydraulic disc brakes provide precise and reliable stopping power, even in wet or muddy conditions.\n\nAgility and control\nBontrager Aeolus Elite 35 carbon wheels make Blaze X1 incredibly agile and responsive, allowing you to navigate tight turns and corners with ease.\n\nSmooth and precise shifting\nThe 12-speed Shimano Ultegra drivetrain offers smooth and precise shifting, so you can easily find the right gear for any terrain.\n\n## Specifications\nFrameset\nFrame\tADV Carbon, tapered head tube, BB90, direct mount rim brakes, internal cable routing, DuoTrap S compatible, 130x9mm QR\nFork\tADV Carbon, tapered steerer, direct mount rim brakes, internal brake routing, 100x9mm QR\n\nWheels\nWheel front\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x9mm QR\nWheel rear\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11-speed freehub, 130x9mm QR\nTire front\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nTire rear\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nMax tire size\t25c Bontrager tires (with at least 4mm of clearance to frame)\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 12 speed\nFront derailleur\tShimano Ultegra R8000, braze-on\nRear derailleur\tShimano Ultegra R8000, short cage, 30T max cog\nCrank\tSize: 50, 52, 54\nShimano Ultegra R8000, 50/34 (compact), 170mm length\nSize: 56, 58, 60, 62\nShimano Ultegra R8000, 50/34 (compact), 172.5mm length\nBottom bracket\tBB90, Shimano press-fit\nCassette\tShimano Ultegra R8000, 11-30, 12 speed\nChain\tShimano Ultegra HG701, 12 speed\n\nComponents\nSaddle\tBontrager Montrose Elite, titanium rails, 138mm width\nSeatpost\tBontrager carbon seatmast cap, 20mm offset\nHandlebar\tBontrager Elite Aero VR-CF, alloy, 31.8mm, internal cable routing, 40cm width\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Elite, 31.8mm, Blendr-compatible, 7 degree, 80mm length\nBrake Shimano Ultegra hydraulic disc brake\n\nWeight\nWeight\t56 - 8.91 kg / 19.63 lbs (with tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider height |\n|------|-------------|\n| 50 | 162-166cm |\n| 52 | 165-170cm |\n| 54 | 168-174cm |\n| 56 | 174-180cm |\n| 58 | 179-184cm |\n| 60 | 184-189cm |\n| 62 | 189-196cm |\n\n## Geometry\n| Frame size | 50cm | 52cm | 54cm | 56cm | 58cm | 60cm | 62cm |\n|------------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A - Seat tube | 443mm | 460mm | 478mm | 500mm | 520mm | 540mm | 560mm |\n| B - Seat tube angle | 74.1° | 73.9° | 73.7° | 73.4° | 73.2° | 73.0° | 72.8° |\n| C - Head tube length | 100mm | 110mm | 130mm | 150mm | 170mm | 190mm | 210mm |\n| D - Head angle | 71.4° | 72.0° | 72.5° | 73.0° | 73.3° | 73.6° | 73.8° |\n| E - Effective top tube | 522mm | 535mm | 547mm | 562mm | 577mm | 593mm | 610mm |\n| F - Bottom bracket height | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm |\n| G - Bottom bracket drop | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm |\n| H - Chainstay length | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm |\n| I - Offset | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm |\n| J - Trail | 65mm | 62mm | 59mm | 56mm | 55mm | 53mm | 52mm |\n| K - Wheelbase | 983mm | 983mm | 990mm | 1005mm | 1019mm | 1036mm | 1055mm |\n| L - Standover | 741mm | 765mm | 787mm | 806mm | 825mm | 847mm | 869mm |", + "price": 799.99, + "tags": [ + "bicycle", + "mountain bike" + ] + }, + { + "name": "Celerity X5", + "shortDescription": "Celerity X5 is a versatile and reliable road bike that is designed for experienced and amateur riders alike. It's designed to provide smooth and comfortable rides over long distances. With an ultra-lightweight and responsive carbon fiber frame, Shimano 105 groupset, hydraulic disc brakes, and 28mm wide tires, this bike ensures efficient power transfer, precise handling, and superior stopping power.", + "description": "## Overview\n\nIt's right for you if... \nYou are looking for a high-performance road bike that offers a perfect balance of speed, comfort, and control. You enjoy long-distance rides and need a bike that is designed to handle various road conditions with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nCelerity X5 is equipped with a full carbon fiber frame that ensures maximum strength and durability while keeping the weight down. It features a Shimano 105 groupset with 11-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power, and 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that offers comfort, speed, and control, Celerity X5 is the perfect choice. With its lightweight carbon fiber frame, reliable components, and advanced technology, this bike is designed to help you enjoy long-distance rides with ease.\n\n## Features \n\nLightweight and responsive \nCelerity X5 comes with a full carbon fiber frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon seat post provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tCelerity X5 Full Carbon Fiber Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tCelerity X5 Full Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tCelerity X5 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano 105 R7025 Hydraulic Disc Shifters \nFront Derailleur\tShimano 105 R7000 \nRear Derailleur\tShimano 105 R7000 \nCrankset\tShimano 105 R7000 50-34T \nBottom Bracket\tShimano BB72-41B \nCassette\tShimano 105 R7000 11-30T \nChain\tShimano HG601 11-Speed Chain \n\nComponents \nSaddle\tSelle Royal Asphalt Saddle \nSeatpost\tCelerity X5 Carbon Seatpost \nHandlebar\tCelerity X5 Compact Handlebar \nStem\tCelerity X5 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano 105 R7025 Hydraulic Disc Brakes \nRotors\tShimano SM-RT70 160mm Rotors \n\nAccessories \nPedals\tCelerity X5 Road Pedals \n\nWeight \nWeight\t8.2 kg / 18.1 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", + "price": 399.99, + "tags": [ + "bicycle", + "city bike" + ] + }, + { + "name": "Velocity V8", + "shortDescription": "Velocity V8 is a high-performance road bike that is designed to deliver speed, agility, and control on the road. With its lightweight aluminum frame, carbon fiber fork, Shimano Tiagra groupset, and hydraulic disc brakes, this bike is perfect for experienced riders who are looking for a fast and responsive bike that can handle various road conditions.", + "description": "## Overview\n\nIt's right for you if... \nYou are an experienced rider who is looking for a high-performance road bike that is lightweight, agile, and responsive. You want a bike that can handle long-distance rides, steep climbs, and fast descents with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nVelocity V8 features a lightweight aluminum frame with a carbon fiber fork that ensures a comfortable ride without sacrificing stiffness and power transfer. It comes with a Shimano Tiagra groupset with 10-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power in all weather conditions, while 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that is lightweight, fast, and responsive, Velocity V8 is the perfect choice. With its lightweight aluminum frame, reliable components, and advanced technology, this bike is designed to help you enjoy fast and comfortable rides on the road.\n\n## Features \n\nLightweight and responsive \nVelocity V8 comes with a lightweight aluminum frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon fork provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tVelocity V8 Aluminum Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tVelocity V8 Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tVelocity V8 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano Tiagra Hydraulic Disc Shifters \nFront Derailleur\tShimano Tiagra \nRear Derailleur\tShimano Tiagra \nCrankset\tShimano Tiagra 50-34T \nBottom Bracket\tShimano BB-RS500-PB \nCassette\tShimano Tiagra 11-32T \nChain\tShimano HG54 10-Speed Chain \n\nComponents \nSaddle\tVelocity V8 Saddle \nSeatpost\tVelocity V8 Aluminum Seatpost \nHandlebar\tVelocity V8 Compact Handlebar \nStem\tVelocity V8 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano Tiagra Hydraulic Disc Brakes \nRotors\tShimano SM-RT64 160mm Rotors \n\nAccessories \nPedals\tVelocity V8 Road Pedals \n\nWeight \nWeight\t9.4 kg / 20.7 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", + "price": 1899.99, + "tags": [ + "bicycle", + "electric bike" + ] + }, + { + "name": "VeloCore X9 eMTB", + "shortDescription": "The VeloCore X9 eMTB is a light, agile and versatile electric mountain bike designed for adventure and performance. Its purpose-built frame and premium components offer an exhilarating ride experience on both technical terrain and smooth singletrack.", + "description": "## Overview\nIt's right for you if...\nYou love exploring new trails and testing your limits on challenging terrain. You want an electric mountain bike that offers power when you need it, without sacrificing performance or agility. You're looking for a high-quality bike with top-notch components and a sleek design.\n\nThe tech you get\nA lightweight, full carbon frame with custom geometry, a 140mm RockShox Pike Ultimate fork with Charger 2.1 damper, and a Fox Float DPS Performance shock. A Shimano STEPS E8000 motor and 504Wh battery that provide up to 62 miles of range and 20 mph assistance. A Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels.\n\nThe final word\nThe VeloCore X9 eMTB delivers power and agility in equal measure. It's a versatile and capable electric mountain bike that can handle any trail with ease. With premium components, a custom carbon frame, and a sleek design, this bike is built for adventure.\n\n## Features\nAgile and responsive\n\nThe VeloCore X9 eMTB is designed to be nimble and responsive on the trail. Its custom carbon frame offers a perfect balance of stiffness and compliance, while the suspension system provides smooth and stable performance on technical terrain.\n\nPowerful and efficient\n\nThe Shimano STEPS E8000 motor and 504Wh battery provide up to 62 miles of range and 20 mph assistance. The motor delivers smooth and powerful performance, while the battery offers reliable and consistent power for long rides.\n\nCustomizable ride experience\n\nThe VeloCore X9 eMTB comes with an intuitive and customizable Shimano STEPS display that allows you to adjust the level of assistance, monitor your speed and battery life, and customize your ride experience to suit your needs.\n\nPremium components\n\nThe VeloCore X9 eMTB is equipped with high-end components, including a Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels. These components offer reliable and precise performance, allowing you to push your limits with confidence.\n\n## Specs\nFrameset\nFrame\tVeloCore carbon fiber frame, Boost, tapered head tube, internal cable routing, 140mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 damper, DebonAir spring, 15x110mm Boost Maxle Ultimate, 46mm offset, 140mm travel\nShock\tFox Float DPS Performance, EVOL, 3-position adjust, Kashima Coat, 210x50mm\n\nWheels\nWheel front\tDT Swiss XM1700 Spline, 30mm internal width, 15x110mm Boost axle\nWheel rear\tDT Swiss XM1700 Spline, 30mm internal width, Shimano Microspline driver, 12x148mm Boost axle\nTire front\tMaxxis Minion DHF, 29x2.5\", EXO+ casing, tubeless ready\nTire rear\tMaxxis Minion DHR II, 29x2.4\", EXO+ casing, tubeless ready\n\nDrivetrain\nShifter\tShimano XT M8100, 12-speed\nRear derailleur\tShimano XT M8100, Shadow Plus, long cage, 51T max cog\nCrankset\tShimano STEPS E8000, 165mm length, 34T chainring\nCassette\tShimano XT M8100, 10-51T, 12-speed\nChain\tShimano CN-M8100, 12-speed\nPedals\tNot included\n\nComponents\nSaddle\tBontrager Arvada, hollow chromoly rails\nSeatpost\tDrop Line, internal routing, 31.6mm (15.5: 100mm, 17.5 & 18.5: 125mm, 19.5 & 21.5: 150mm)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nStem\tBontrager Line Pro, 35mm, Knock Block, 0 degree, 50mm length\nGrips\tBontrager XR Trail Elite, alloy lock-on\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrakeset\tShimano SLX M7120, 4-piston hydraulic disc\n\nAccessories\nBattery\tShimano STEPS BT-E8010, 504Wh\nCharger\tShimano STEPS EC-E8004, 4A\nController\tShimano STEPS E8000 display\nBike weight\tM - 22.5 kg / 49.6 lbs (with tubes)\n\n## Sizing & fit\n\n| Size | Rider Height |\n|:----:|:------------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" |\n| M | 170 - 178 cm 5'7\" - 5'10\"|\n| L | 178 - 186 cm 5'10\" - 6'1\"|\n| XL | 186 - 196 cm 6'1\" - 6'5\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| A — Seat tube | 40.6 | 43.2 | 47.0 | 51.0 |\n| B — Seat tube angle | 75.0° | 75.0° | 75.0° | 75.0° |\n| C — Head tube length | 9.6 | 10.6 | 11.6 | 12.6 |\n| D — Head angle | 66.5° | 66.5° | 66.5° | 66.5° |\n| E — Effective top tube | 60.4 | 62.6 | 64.8 | 66.9 |\n| F — Bottom bracket height | 33.2 | 33.2 | 33.2 | 33.2 |\n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 |\n| H — Chainstay length | 45.5 | 45.5 | 45.5 | 45.5 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 11.9 | 11.9 | 11.9 | 11.9 |\n| K — Wheelbase | 117.0 | 119.3 | 121.6 | 123.9 |\n| L — Standover | 75.9 | 75.9 | 78.6 | 78.6 |\n| M — Frame reach | 43.6 | 45.6 | 47.6 | 49.6 |\n| N — Frame stack | 60.5 | 61.5 | 62.4 | 63.4 |", + "price": 1299.99, + "tags": [ + "bicycle", + "touring bike" + ] + }, + { + "name": "Zephyr 8.8 GX Eagle AXS Gen 3", + "shortDescription": "Zephyr 8.8 GX Eagle AXS is a light and nimble full-suspension mountain bike. It's designed to handle technical terrain with ease and has a smooth and efficient ride feel. The sleek and powerful Bosch Performance Line CX motor and removable Powertube battery provide a boost to your pedaling and give you long-lasting riding time. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.", + "description": "## Overview\nIt's right for you if...\nYou're an avid mountain biker looking for a high-performance e-MTB that can tackle challenging trails. You want a bike with a powerful motor, efficient suspension, and advanced technology to enhance your riding experience. You also need a bike that's reliable and durable for long-lasting use.\n\nThe tech you get\nA lightweight, full carbon frame with 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. A Bosch Performance Line CX motor and removable Powertube 625Wh battery that can assist up to 20mph when it's on and gives zero drag when it's off, plus an easy-to-use handlebar-mounted Bosch Purion controller. A SRAM GX Eagle AXS wireless electronic drivetrain, a RockShox Reverb Stealth dropper, and DT Swiss HX1501 Spline One wheels.\n\nThe final word\nZephyr 8.8 GX Eagle AXS is a high-performance e-MTB that's designed to handle technical terrain with ease. With a powerful Bosch motor and long-lasting battery, you can conquer challenging climbs and enjoy long rides. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.\n\n## Features\nPowerful motor\n\nThe Bosch Performance Line CX motor provides a boost to your pedaling and can assist up to 20mph. It has four power modes and a walk-assist function for easy navigation on steep climbs. The motor is also reliable and durable for long-lasting use.\n\nEfficient suspension\n\nZephyr 8.8 has a 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. The suspension is efficient and responsive, allowing you to handle technical terrain with ease.\n\nRemovable battery\n\nThe Powertube 625Wh battery is removable for easy charging and storage. It provides long-lasting riding time and can be replaced with a spare battery for even longer rides. The battery is also durable and weather-resistant for all-season riding.\n\nAdvanced technology\n\nZephyr 8.8 is equipped with advanced technology, including a Bosch Purion controller for easy motor control, a SRAM GX Eagle AXS wireless electronic drivetrain for precise shifting, and a RockShox Reverb Stealth dropper for adjustable saddle height. The bike also has DT Swiss HX1501 Spline One wheels for reliable performance on any terrain.\n\nCarbon frame\n\nThe full carbon frame is lightweight and durable, providing a smooth and efficient ride. It's also designed with a tapered head tube, internal cable routing, and Boost148 spacing for enhanced stiffness and responsiveness.\n\n## Specs\nFrameset\nFrame\tCarbon main frame & stays, tapered head tube, internal routing, Boost148, 150mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 RCT3 damper, DebonAir spring, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 160mm travel\nShock\tRockShox Deluxe RT3, DebonAir spring, 205mm x 57.5mm\nMax compatible fork travel\t170mm\n\nWheels\nWheel front\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, 110x15mm Boost\nWheel rear\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, SRAM XD driver, 148x12mm Boost\nTire\tBontrager XR4 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.40''\nMax tire size\t29x2.60\"\n\nDrivetrain\nShifter\tSRAM GX Eagle AXS, wireless, 12 speed\nRear derailleur\tSRAM GX Eagle AXS\nCrank\tBosch Gen 4, 32T\nChainring\tSRAM X-Sync 2, 32T, direct-mount\nCassette\tSRAM PG-1275 Eagle, 10-52, 12 speed\nChain\tSRAM GX Eagle, 12 speed\n\nComponents\nSaddle\tBontrager Arvada, hollow titanium rails, 138mm width\nSeatpost\tRockShox Reverb Stealth, 31.6mm, internal routing, 150mm (S), 170mm (M/L), 200mm (XL)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nGrips\tBontrager XR Trail Elite, alloy lock-on\nStem\tBontrager Line Pro, Knock Block, 35mm, 0 degree, 50mm length\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake\tSRAM Code RSC hydraulic disc, 200mm (front), 180mm (rear)\nBrake rotor\tSRAM CenterLine, centerlock, round edge, 200mm (front), 180mm (rear)\n\nAccessories\nE-bike system\tBosch Performance Line CX\nBattery\tBosch Powertube 625Wh\nCharger\tBosch 4A compact charger\nController\tBosch Purion\nTool\tBontrager multi-tool, integrated storage bag\n\nWeight\nWeight\tM - 24.08 kg / 53.07 lbs (with TLR sealant, no tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 153 - 162 cm 5'0\" - 5'4\" | 67 - 74 cm 26\" - 29\" |\n| M | 161 - 172 cm 5'3\" - 5'8\" | 74 - 79 cm 29\" - 31\" |\n| L | 171 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| XL | 179 - 188 cm 5'10\" - 6'2\" | 84 - 89 cm 33\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 41.9 | 44.5 | 47.6 |\n| B — Seat tube angle | 76.1° | 76.1° | 76.1° | 76.1° |\n| C — Head tube length | 9.6 | 10.5 | 11.5 | 12.5 |\n| D — Head angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| E — Effective top tube | 58.6 | 61.3 | 64.0 | 66.7 |\n| F — Bottom bracket height | 34.0 | 34.0 | 34.0 | 34.0 |\n| G — Bottom bracket drop | 1.0 | 1.0 | 1.0 | 1.0 |\n| H — Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 10.5 | 10.5 | 10.5 | 10.5 |\n| K — Wheelbase | 119.5 | 122.3 | 125.0 | 127.8 |\n| L — Standover | 72.7 | 74.7 | 77.6 | 81.0 |\n|", + "price": 1499.99, + "tags": [ + "bicycle", + "electric bike", + "city bike" + ] + }, + { + "name": "Velo 99 XR1 AXS", + "shortDescription": "Velo 99 XR1 AXS is a next-generation bike designed for fast-paced adventure seekers and speed enthusiasts. Built for high-performance racing, the bike boasts state-of-the-art technology and premium components. It is the ultimate bike for riders who want to push their limits and get their adrenaline pumping.", + "description": "## Overview\nIt's right for you if...\nYou are a passionate cyclist looking for a bike that can keep up with your speed, agility, and endurance. You are an adventurer who loves to explore new terrains and challenge yourself on the toughest courses. You want a bike that is lightweight, durable, and packed with the latest technology.\n\nThe tech you get\nA lightweight, full carbon frame with advanced aerodynamics and integrated cable routing for a clean look. A high-performance SRAM XX1 Eagle AXS wireless electronic drivetrain, featuring a 12-speed cassette and a 32T chainring. A RockShox SID Ultimate fork with a remote lockout, 120mm travel, and Charger Race Day damper. A high-end SRAM G2 Ultimate hydraulic disc brake with carbon levers. A FOX Transfer SL dropper post for quick and easy height adjustments. DT Swiss XRC 1501 carbon wheels for superior speed and handling.\n\nThe final word\nVelo 99 XR1 AXS is a premium racing bike that can help you achieve your goals and reach new heights. It is designed for speed, agility, and performance, and it is packed with the latest technology and premium components. If you are a serious cyclist who wants the best, this is the bike for you.\n\n## Features\nAerodynamic design\n\nThe Velo 99 XR1 AXS features a state-of-the-art frame design that reduces drag and improves speed. It has an aerodynamic seatpost, integrated cable routing, and a sleek, streamlined look that sets it apart from other bikes.\n\nWireless electronic drivetrain\n\nThe SRAM XX1 Eagle AXS drivetrain features a wireless electronic system that provides precise, instant shifting and unmatched efficiency. It eliminates the need for cables and makes the bike lighter and faster.\n\nHigh-performance suspension\n\nThe RockShox SID Ultimate fork and Charger Race Day damper provide 120mm of smooth, responsive suspension that can handle any terrain. The fork also has a remote lockout for quick adjustments on the fly.\n\nSuperior braking power\n\nThe SRAM G2 Ultimate hydraulic disc brake system delivers unmatched stopping power and control. It has carbon levers for a lightweight, ergonomic design and precision control.\n\nCarbon wheels\n\nThe DT Swiss XRC 1501 carbon wheels are ultra-lightweight, yet incredibly strong and durable. They provide superior speed and handling, making the bike more agile and responsive.\n\n## Specs\nFrameset\nFrame\tFull carbon frame, integrated cable routing, aerodynamic design, Boost148\nFork\tRockShox SID Ultimate, Charger Race Day damper, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 120mm travel\n\nWheels\nWheel front\tDT Swiss XRC 1501 carbon wheel, Boost110, 15mm thru axle\nWheel rear\tDT Swiss XRC 1501 carbon wheel, SRAM XD driver, Boost148, 12mm thru axle\nTire\tSchwalbe Racing Ray, Performance Line, Addix, 29x2.25\"\nTire part\tSchwalbe Doc Blue Professional, 500ml\nMax tire size\t29x2.3\"\n\nDrivetrain\nShifter\tSRAM Eagle AXS, wireless, 12-speed\nRear derailleur\tSRAM XX1 Eagle AXS\nCrank\tSRAM XX1 Eagle, 32T, carbon\nChainring\tSRAM X-SYNC, 32T, alloy\nCassette\tSRAM Eagle XG-1299, 10-52, 12-speed\nChain\tSRAM XX1 Eagle, 12-speed\nMax chainring size\t1x: 32T\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tFOX Transfer SL, 125mm travel, internal routing, 31.6mm\nHandlebar\tBontrager Kovee Pro, ADV Carbon, 35mm, 5mm rise, 720mm width\nGrips\tBontrager XR Endurance Elite\nStem\tBontrager Kovee Pro, 35mm, Blendr compatible, 7 degree, 60mm length\nHeadset\tIntegrated, cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrake\tSRAM G2 Ultimate hydraulic disc, carbon levers, 180mm rotors\n\nAccessories\nBike computer\tBontrager Trip 300\nTool\tBontrager Flatline Pro pedal wrench, T25 Torx\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 158 - 168 cm 5'2\" - 5'6\" | 74 - 78 cm 29\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| L | 173 - 183 cm 5'8\" - 6'0\" | 82 - 86 cm 32\" - 34\" |\n| XL | 180 - 193 cm 5'11\" - 6'4\" | 86 - 90 cm 34\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.9 | 43.0 | 47.0 | 51.0 |\n| B — Seat tube angle | 74.5° | 74.5° | 74.5° | 74.5° |\n| C — Head tube length | 9.0 | 10.0 | 11.0 | 12.0 |\n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° |\n| E — Effective top tube | 57.8 | 59.7 | 61.6 | 63.6 |\n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 9.7 | 9.7 | 9.7 | 9.7 |\n| K — Wheelbase | 112.5 | 114.5 | 116.5 | 118.6 |\n| L — Standover | 75.9 | 77.8 | 81.5 | 84.2 |\n| M — Frame reach | 41.6 | 43.4 | 45.2 | 47.1 |\n| N — Frame stack | 58.2 | 58.9 | 59.3 | 59.9 |", + "price": 1099.99, + "tags": [ + "bicycle", + "mountain bike" + ] + }, + { + "name": "AURORA 11S E-MTB", + "shortDescription": "The AURORA 11S is a powerful and stylish electric mountain bike designed to take you on thrilling off-road adventures. With its sturdy frame and premium components, this bike is built to handle any terrain. It features a high-performance motor, long-lasting battery, and advanced suspension system that guarantee a smooth and comfortable ride.", + "description": "## Overview\nIt's right for you if...\nYou want a top-of-the-line e-MTB that is both powerful and stylish. You also want a bike that can handle any terrain, from steep climbs to rocky descents. With its advanced features and premium components, the AURORA 11S is designed for serious off-road riders who demand the best.\n\nThe tech you get\nA sturdy aluminum frame with advanced suspension system that provides 120mm of travel. A 750W brushless motor that delivers up to 28mph, and a 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge. An advanced 11-speed Shimano drivetrain with hydraulic disc brakes for precise shifting and reliable stopping power. \n\nThe final word\nThe AURORA 11S is a top-of-the-line e-MTB that delivers exceptional performance and style. Whether you're tackling steep climbs or hitting rocky descents, this bike is built to handle any terrain with ease. With its advanced features and premium components, the AURORA 11S is the perfect choice for serious off-road riders who demand the best.\n\n## Features\nPowerful and efficient\n\nThe AURORA 11S is equipped with a high-performance 750W brushless motor that delivers up to 28mph. The motor is powered by a long-lasting 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge.\n\nAdvanced suspension system\n\nThe bike's advanced suspension system provides 120mm of travel, ensuring a smooth and comfortable ride on any terrain. The front suspension is a Suntour XCR32 Air fork, while the rear suspension is a KS-281 hydraulic shock absorber.\n\nPremium components\n\nThe AURORA 11S features an advanced 11-speed Shimano drivetrain with hydraulic disc brakes. The bike is also equipped with a Tektro HD-E725 hydraulic disc brake system that provides reliable stopping power.\n\nSleek and stylish design\n\nWith its sleek and stylish design, the AURORA 11S is sure to turn heads on the trail. The bike's sturdy aluminum frame is available in a range of colors, including black, blue, and red.\n\n## Specs\nFrameset\nFrame Material: Aluminum\nFrame Size: S, M, L\nFork: Suntour XCR32 Air, 120mm Travel\nShock Absorber: KS-281 Hydraulic Shock Absorber\n\nWheels\nWheel Size: 27.5 inches\nTires: Kenda K1151 Nevegal, 27.5x2.35\nRims: Alloy Double Wall\nSpokes: 32H, Stainless Steel\n\nDrivetrain\nShifters: Shimano SL-M7000\nRear Derailleur: Shimano RD-M8000\nCrankset: Prowheel 42T, Alloy Crank Arm\nCassette: Shimano CS-M7000, 11-42T\nChain: KMC X11EPT\n\nBrakes\nBrake System: Tektro HD-E725 Hydraulic Disc Brake\nBrake Rotors: 180mm Front, 160mm Rear\n\nE-bike system\nMotor: 750W Brushless\nBattery: 48V/14Ah Lithium-Ion\nCharger: 48V/3A Smart Charger\nController: Intelligent Sinusoidal Wave\n\nWeight\nWeight: 59.5 lbs\n\n## Sizing & fit\n| Size | Rider Height | Standover Height |\n|------|-------------|-----------------|\n| S | 5'2\"-5'6\" | 28.5\" |\n| M | 5'7\"-6'0\" | 29.5\" |\n| L | 6'0\"-6'4\" | 30.5\" |\n\n## Geometry\nAll measurements provided in cm.\nSizing table\n| Frame size letter | S | M | L |\n|-------------------|-----|-----|-----|\n| Wheel Size | 27.5\"| 27.5\"| 27.5\"|\n| Seat tube length | 44.5| 48.5| 52.5|\n| Head tube angle | 68° | 68° | 68° |\n| Seat tube angle | 74.5°| 74.5°| 74.5°|\n| Effective top tube | 57.5| 59.5| 61.5|\n| Head tube length | 12.0| 12.0| 13.0|\n| Chainstay length | 45.5| 45.5| 45.5|\n| Bottom bracket height | 30.0| 30.0| 30.0|\n| Wheelbase | 115.0|116.5|118.5|", + "price": 1999.99, + "tags": [ + "bicycle", + "road bike" + ] + }, + { + "name": "VeloTech V9.5 AXS Gen 3", + "shortDescription": "VeloTech V9.5 AXS is a sleek and fast carbon bike that combines high-end tech with a comfortable ride. It's designed to provide the ultimate experience for the most serious riders. The bike comes with a lightweight and powerful motor that can be activated when needed, and you get a spec filled with premium parts.", + "description": "## Overview\nIt's right for you if...\nYou want a bike that is fast, efficient, and delivers an adrenaline-filled experience. You are looking for a bike that is built with cutting-edge technology, and you want a ride that is both comfortable and exciting.\n\nThe tech you get\nA lightweight and durable full carbon frame with a fork that has 100mm of travel. The bike comes with a powerful motor that can deliver up to 20 mph of assistance. The drivetrain is a wireless electronic system that is precise and reliable. The bike is also equipped with hydraulic disc brakes, tubeless-ready wheels, and comfortable grips.\n\nThe final word\nThe VeloTech V9.5 AXS is a high-end bike that delivers an incredible experience for serious riders. It combines the latest technology with a comfortable ride, making it perfect for long rides, tough climbs, and fast descents.\n\n## Features\nFast and efficient\nThe VeloTech V9.5 AXS comes with a powerful motor that can provide up to 20 mph of assistance. The motor is lightweight and efficient, providing a boost when you need it without adding bulk. The bike's battery is removable, allowing you to ride without assistance when you don't need it.\n\nSmart software for the trail\nThe VeloTech V9.5 AXS is equipped with intelligent software that delivers a smooth and responsive ride. The software allows the motor to respond immediately as you start to pedal, delivering more power over a wider cadence range. You can also customize your user settings to suit your preferences.\n\nComfortable ride\nThe VeloTech V9.5 AXS is designed to provide a comfortable ride, even on long rides. The bike's fork has 100mm of travel, providing ample cushioning for rough terrain. The bike's grips are also designed to provide a comfortable and secure grip, even on the most challenging rides.\n\n## Specs\nFrameset\nFrame\tCarbon fiber frame with internal cable routing and Boost148\nFork\t100mm of travel with remote lockout\nShock\tN/A\n\nWheels\nWheel front\tCarbon fiber tubeless-ready wheel\nWheel rear\tCarbon fiber tubeless-ready wheel\nSkewer rear\t12mm thru-axle\nTire\tTubeless-ready tire\nTire part\tTubeless sealant\n\nDrivetrain\nShifter\tWireless electronic shifter\nRear derailleur\tWireless electronic derailleur\nCrank\tCarbon fiber crankset with chainring\nCrank arm\tCarbon fiber crank arm\nChainring\tAlloy chainring\nCassette\t12-speed cassette\nChain\t12-speed chain\n\nComponents\nSaddle\tCarbon fiber saddle\nSeatpost\tCarbon fiber seatpost\nHandlebar\tCarbon fiber handlebar\nGrips\tComfortable and secure grips\nStem\tCarbon fiber stem\nHeadset\tCarbon fiber headset\nBrake\tHydraulic disc brakes\nBrake rotor\tDisc brake rotor\n\nAccessories\nE-bike system\tPowerful motor with removable battery\nBattery\tLithium-ion battery\nCharger\tFast charging adapter\nController\tHandlebar-mounted controller\nTool\tBasic toolkit\n\nWeight\nWeight\tM - 17.5 kg / 38.5 lbs (with tubeless sealant)\n\nWeight limit\nThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing & fit\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 160 - 170 cm 5'3\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| M | 170 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| L | 180 - 190 cm 5'11\" - 6'3\" | 84 - 89 cm 33\" - 35\" |\n| XL | 190 - 200 cm 6'3\" - 6'7\" | 89 - 94 cm 35\" - 37\" |\n\n## Geometry\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 50.0 | 53.3 | 55.6 | 58.8 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 43.2 | 48.3 | 53.3 |\n| B — Seat tube angle | 72.3° | 72.6° | 72.8° | 72.8° |\n| C — Head tube length | 9.0 | 10.0 | 10.5 | 11.0 |\n| D — Head angle | 67.5° | 67.5° | 67.5° | 67.5° |\n| E — Effective top tube | 58.0 | 61.7 | 64.8 | 67.0 |\n| F — Bottom bracket height | 32.3 | 32.3 | 32.3 | 32.3 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 44.7 | 44.7 | 44.7 | 44.7 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 |\n| K — Wheelbase | 112.6 | 116.5 | 119.7 | 121.9 |\n| L — Standover | 76.8 | 76.8 | 76.8 | 76.8 |\n| M — Frame reach | 40.5 | 44.0 | 47.0 | 49.0 |\n| N — Frame stack | 60.9 | 61.8 | 62.2 | 62.7 |", + "price": 1699.99, + "tags": [ + "bicycle", + "electric bike", + "city bike" + ] + }, + { + "name": "Axiom D8 E-Mountain Bike", + "shortDescription": "The Axiom D8 is an electrifying mountain bike that is built for adventure. It boasts a light aluminum frame, a powerful motor and the latest tech to tackle the toughest of terrains. The D8 provides assistance without adding bulk to the bike, giving you the flexibility to ride like a traditional mountain bike or have an extra push when you need it.", + "description": "## Overview \nIt's right for you if... \nYou're looking for an electric mountain bike that can handle a wide variety of terrain, from flowing singletrack to technical descents. You also want a bike that offers a powerful motor that provides assistance without adding bulk to the bike. The D8 is designed to take you anywhere, quickly and comfortably.\n\nThe tech you get \nA lightweight aluminum frame with 140mm of travel, a Suntour fork with hydraulic lockout, and a reliable and powerful Bafang M400 mid-motor that provides a boost up to 20 mph. The bike features a Shimano Deore drivetrain, hydraulic disc brakes, and a dropper seat post. With the latest tech on-board, the D8 is designed to take you to new heights.\n\nThe final word \nThe Axiom D8 is an outstanding electric mountain bike that is designed for adventure. It's built with the latest tech and provides the flexibility to ride like a traditional mountain bike or have an extra push when you need it. Whether you're a beginner or an experienced rider, the D8 is the perfect companion for your next adventure.\n\n## Features \nBuilt for Adventure \n\nThe D8 features a lightweight aluminum frame that is built to withstand rugged terrain. It comes equipped with 140mm of travel and a Suntour fork that can handle even the toughest of trails. With this bike, you're ready to take on anything the mountain can throw at you.\n\nPowerful Motor \n\nThe Bafang M400 mid-motor provides reliable and powerful assistance without adding bulk to the bike. You can quickly and easily switch between the different assistance levels to find the perfect balance between range and power.\n\nShimano Deore Drivetrain \n\nThe Shimano Deore drivetrain is reliable and offers smooth shifting on any terrain. You can easily adjust the gears to match your riding style and maximize your performance on the mountain.\n\nDropper Seat Post \n\nThe dropper seat post allows you to easily adjust your seat height on the fly, so you can maintain the perfect position for any terrain. With the flick of a switch, you can quickly and easily lower or raise your seat to match the terrain.\n\nHydraulic Disc Brakes \n\nThe D8 features powerful hydraulic disc brakes that offer reliable stopping power in any weather condition. You can ride with confidence knowing that you have the brakes to stop on a dime.\n\n## Specs \nFrameset \nFrame\tAluminum frame with 140mm of travel \nFork\tSuntour fork with hydraulic lockout, 140mm of travel \nShock\tN/A \nMax compatible fork travel\t140mm \n \nWheels \nWheel front\tAlloy wheel \nWheel rear\tAlloy wheel \nSkewer rear\tThru axle \nTire\t29\" x 2.35\" \nTire part\tN/A \nMax tire size\t29\" x 2.6\" \n \nDrivetrain \nShifter\tShimano Deore \nRear derailleur\tShimano Deore \nCrank\tBafang M400 \nCrank arm\tN/A \nChainring\tN/A \nCassette\tShimano Deore \nChain\tShimano Deore \nMax chainring size\tN/A \n \nComponents \nSaddle\tAxiom D8 saddle \nSeatpost\tDropper seat post \nHandlebar\tAxiom D8 handlebar \nGrips\tAxiom D8 grips \nStem\tAxiom D8 stem \nHeadset\tAxiom D8 headset \nBrake\tHydraulic disc brakes \nBrake rotor\t180mm \n\nAccessories \nE-bike system\tBafang M400 mid-motor \nBattery\tLithium-ion battery, 500Wh \nCharger\tLithium-ion charger \nController\tBafang M400 controller \nTool\tN/A \n \nWeight \nWeight\tM - 22 kg / 48.5 lbs \nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 136 kg (300 lbs). \n \n \n## Sizing & fit \n \n| Size | Rider Height | Inseam | \n|:----:|:------------------------:|:--------------------:| \n| S | 152 - 165 cm 5'0\" - 5'5\" | 70 - 76 cm 27\" - 30\" | \n| M | 165 - 178 cm 5'5\" - 5'10\" | 76 - 81 cm 30\" - 32\" | \n| L | 178 - 185 cm 5'10\" - 6'1\" | 81 - 86 cm 32\" - 34\" | \n| XL | 185 - 193 cm 6'1\" - 6'4\" | 86 - 91 cm 34\" - 36\" | \n \n \n## Geometry \n \nAll measurements provided in cm unless otherwise noted. \nSizing table \n| Frame size letter | S | M | L | XL | \n|---------------------------|-------|-------|-------|-------| \n| Actual frame size | 41.9 | 46.5 | 50.8 | 55.9 | \n| Wheel size | 29\" | 29\" | 29\" | 29\" | \n| A — Seat tube | 42.0 | 46.5 | 51.0 | 56.0 | \n| B — Seat tube angle | 74.0° | 74.0° | 74.0° | 74.0° | \n| C — Head tube length | 11.0 | 12.0 | 13.0 | 15.0 | \n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° | \n| E — Effective top tube | 57.0 | 60.0 | 62.0 | 65.0 | \n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 | \n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 | \n| H — Chainstay length | 46.0 | 46.0 | 46.0 | 46.0 | \n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | \n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 | \n| K — Wheelbase | 113.0 | 116.0 | 117.5 | 120.5 | \n| L — Standover | 73.5 | 75.5 | 76.5 | 79.5 | \n| M — Frame reach | 41.0 | 43.5 | 45.0 | 47.5 | \n| N — Frame stack | 60.5 | 61.5 | 62.5 | 64.5 |", + "price": 1399.99, + "tags": [ + "bicycle", + "electric bike", + "mountain bike" + ] + }, + { + "name": "Velocity X1", + "shortDescription": "Velocity X1 is a high-performance road bike designed for speed enthusiasts. It features a lightweight yet durable frame, aerodynamic design, and top-quality components, making it the perfect choice for those who want to take their cycling experience to the next level.", + "description": "## Overview\nIt's right for you if...\nYou're an experienced cyclist looking for a bike that can keep up with your need for speed. You want a bike that's lightweight, aerodynamic, and built to perform, whether you're training for a race or just pushing yourself to go faster.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork, Shimano Ultegra groupset with a wide range of gearing, hydraulic disc brakes, aerodynamic carbon wheels, and a vibration-absorbing handlebar with ergonomic grips.\n\nThe final word\nVelocity X1 is the ultimate road bike for speed enthusiasts. Its lightweight frame, aerodynamic design, and top-quality components make it the perfect choice for those who want to take their cycling experience to the next level.\n\n\n## Features\n\nAerodynamic design\nVelocity X1 is built with an aerodynamic design to help you go faster with less effort. It features a sleek profile, hidden cables, and a carbon fork that cuts through the wind, reducing drag and increasing speed.\n\nHydraulic disc brakes\nVelocity X1 comes equipped with hydraulic disc brakes, providing excellent stopping power in all weather conditions. They're also low maintenance, with minimal adjustments needed over time.\n\nCarbon wheels\nThe Velocity X1's aerodynamic carbon wheels provide excellent speed and responsiveness, helping you achieve your fastest times yet. They're also lightweight, reducing overall bike weight and making acceleration and handling even easier.\n\nShimano Ultegra groupset\nThe Shimano Ultegra groupset provides smooth shifting and reliable performance, ensuring you get the most out of every ride. With a wide range of gearing options, it's ideal for tackling any terrain, from steep climbs to fast descents.\n\n\n## Specifications\nFrameset\nFrame with Fork\tAluminium frame, internal cable routing, 135x9mm QR\nFork\tCarbon, hidden cable routing, 100x9mm QR\n\nWheels\nWheel front\tCarbon, 30mm deep rim, 23mm width, 100x9mm QR\nWheel rear\tCarbon, 30mm deep rim, 23mm width, 135x9mm QR\nSkewer front\t100x9mm QR\nSkewer rear\t135x9mm QR\nTire\tContinental Grand Prix 5000, 700x25mm, folding bead\nMax tire size\t700x28mm without fenders\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 11 speed\nRear derailleur\tShimano Ultegra R8000, 11 speed\n*Crank\tSize: S, M\nShimano Ultegra R8000, 50/34T, 170mm length\nSize: L, XL\nShimano Ultegra R8000, 50/34T, 175mm length\nBottom bracket\tShimano BB-RS500-PB, PressFit\nCassette\tShimano Ultegra R8000, 11-30T, 11 speed\nChain\tShimano Ultegra HG701, 11 speed\nPedal\tNot included\nMax chainring size\t50/34T\n\nComponents\nSaddle\tBontrager Montrose Comp, steel rails, 138mm width\nSeatpost\tBontrager Comp, 6061 alloy, 27.2mm, 8mm offset, 330mm length\n*Handlebar\tSize: S, M, L\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 400mm width\nSize: XL\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 420mm width\nGrips\tBontrager Supertack Perf tape\n*Stem\tSize: S, M, L\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 90mm length\nSize: XL\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 100mm length\nBrake\tShimano Ultegra R8070 hydraulic disc, flat mount\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.15 kg / 17.97 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" | 74 - 78 cm 29\" - 31\" |\n| M | 170 - 178 cm 5'7\" - 5'10\" | 77 - 82 cm 30\" - 32\" |\n| L | 178 - 186 cm 5'10\" - 6'1\" | 82 - 86 cm 32\" - 34\" |\n| XL | 186 - 196 cm 6'1\" - 6'5\" | 87 - 92 cm 34\" - 36\" |\n\n\n## Geometry\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.0 | 52.0 | 54.0 | 56.0 |\n| B — Seat tube angle | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 13.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 71.0° | 72.0° | 72.0° | 72.5° |\n| E — Effective top tube | 53.7 | 55.0 | 56.5 | 58.0 |\n| F — Bottom bracket height | 27.5 | 27.5 | 27.5 | 27.5 |\n| G — Bottom bracket drop | 7.3 | 7.3 | 7.3 | 7.3 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 5.8 |\n| K — Wheelbase | 98.2 | 99.1 | 100.1 | 101.0 |\n| L — Standover | 75.2 | 78.2 | 81.1 | 84.1 |\n| M — Frame reach | 37.5 | 38.3 | 39.1 | 39.9 |\n| N — Frame stack | 53.3 | 55.4 | 57.4 | 59.5 |", + "price": 1799.99, + "tags": [ + "bicycle", + "touring bike" + ] + }, + { + "name": "Velocity V9", + "shortDescription": "Velocity V9 is a high-performance hybrid bike that combines speed and comfort for riders who demand the best of both worlds. The lightweight aluminum frame, along with the carbon fork and seat post, provide optimal stiffness and absorption to tackle any terrain. A 2x Shimano Deore drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires make it a versatile ride for commuters, fitness riders, and weekend adventurers alike.", + "description": "## Overview\nIt's right for you if...\nYou want a fast, versatile bike that can handle anything from commuting to weekend adventures. You value comfort as much as speed and performance. You want a reliable and durable bike that will last for years to come.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork and seat post, a 2x Shimano Deore drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. The Velocity V9 is designed for riders who demand both performance and comfort in one package.\n\nThe final word\nThe Velocity V9 is the perfect bike for riders who want speed and performance without sacrificing comfort. The lightweight aluminum frame and carbon components provide optimal stiffness and absorption, while the 2x Shimano Deore drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're commuting, hitting the trails, or training for your next race, the Velocity V9 has everything you need to achieve your goals.\n\n## Features\n\n2x drivetrain\nA 2x drivetrain means more versatility and a wider range of gearing options. Whether you're climbing hills or sprinting on the flats, the Velocity V9 has the perfect gear for any situation.\n\nCarbon components\nThe Velocity V9 features a carbon fork and seat post to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unparalleled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\n## Specifications\nFrameset\nFrame with Fork\tAluminum frame with carbon fork and seat post, internal cable routing, fender mounts, 135x5mm ThruSkew\nFork\tCarbon fork, hidden fender mounts, flat mount disc, 5x100mm thru-skew\n\nWheels\nWheel front\tDouble wall aluminum rims, 700c, quick release hub\nWheel rear\tDouble wall aluminum rims, 700c, quick release hub\nTire\tKenda Kwick Tendril, puncture resistant, reflective sidewall, 700x32c\nMax tire size\t700x35c without fenders, 700x32c with fenders\n\nDrivetrain\nShifter\tShimano Deore, 10 speed\nFront derailleur\tShimano Deore\nRear derailleur\tShimano Deore\nCrank\tShimano Deore, 46-30T, 170mm (S/M), 175mm (L/XL)\nBottom bracket\tShimano BB52, 68mm, threaded\nCassette\tShimano Deore, 11-36T, 10 speed\nChain\tShimano HG54, 10 speed\nPedal\tWellgo alloy platform\n\nComponents\nSaddle\tVelo VL-2158, steel rails\nSeatpost\tCarbon seat post, 27.2mm\nHandlebar\tAluminum, 31.8mm clamp, 15mm rise, 680mm width\nGrips\tVelo ergonomic grips\nStem\tAluminum, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, MT200 lever, MT200 caliper\nBrake rotor\tShimano RT56, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 11.5 kg / 25.35 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 44.0 | 48.0 | 52.0 | 56.0 |\n| B — Seat tube angle | 74.5° | 74.0° | 73.5° | 73.0° |\n| C — Head tube length | 14.5 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 71.0° | 71.0° | 71.5° | 71.5° |\n| E — Effective top tube | 56.5 | 57.5 | 58.5 | 59.5 |\n| F — Bottom bracket height | 27.0 | 27.0 | 27.0 | 27.0 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 7.0 | 7.0 | 6.6 | 6.6 |\n| K — Wheelbase | 105.4 | 106.3 | 107.2 | 108.2 |\n| L — Standover | 73.2 | 77.1 | 81.2 | 85.1 |\n| M — Frame reach | 39.0 | 39.8 | 40.4 | 41.3 |\n| N — Frame stack | 57.0 | 58.5 | 60.0 | 61.5 |", + "price": 2199.99, + "tags": [ + "bicycle", + "electric bike", + "mountain bike" + ] + }, + { + "name": "Aero Pro X", + "shortDescription": "Aero Pro X is a high-end racing bike designed for serious cyclists who demand speed, agility, and superior performance. The lightweight carbon frame and fork, combined with the aerodynamic design, provide optimal stiffness and efficiency to maximize your speed. The bike features a 2x Shimano Ultegra drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires. Whether you're competing in a triathlon or climbing steep hills, Aero Pro X delivers exceptional performance and precision handling.", + "description": "## Overview\nIt's right for you if...\nYou are a competitive cyclist looking for a bike that is designed for racing. You want a bike that delivers exceptional speed, agility, and precision handling. You demand superior performance and reliability from your equipment.\n\nThe tech you get\nA lightweight carbon frame with an aerodynamic design, a carbon fork with hidden fender mounts, a 2x Shimano Ultegra drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. Aero Pro X is designed for serious cyclists who demand nothing but the best.\n\nThe final word\nAero Pro X is the ultimate racing bike for serious cyclists. The lightweight carbon frame and aerodynamic design deliver maximum speed and efficiency, while the 2x Shimano Ultegra drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're competing in a triathlon or a criterium race, Aero Pro X delivers the performance you need to win.\n\n## Features\n\nAerodynamic design\nThe Aero Pro X features an aerodynamic design that reduces drag and maximizes efficiency. The bike is optimized for speed and agility, so you can ride faster and farther with less effort.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unrivaled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\nCarbon components\nThe Aero Pro X features a carbon fork with hidden fender mounts to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\n## Specifications\nFrameset\nFrame with Fork\tCarbon frame with an aerodynamic design, internal cable routing, 3s chain keeper, 142x12mm thru-axle\nFork\tCarbon fork with hidden fender mounts, flat mount disc, 100x12mm thru-axle\n\nWheels\nWheel front\tDouble wall carbon rims, 700c, thru-axle hub\nWheel rear\tDouble wall carbon rims, 700c, thru-axle hub\nTire\tContinental Grand Prix 5000, folding bead, 700x25c\nMax tire size\t700x28c without fenders, 700x25c with fenders\n\nDrivetrain\nShifter\tShimano Ultegra, 11 speed\nFront derailleur\tShimano Ultegra\nRear derailleur\tShimano Ultegra\nCrank\tShimano Ultegra, 52-36T, 170mm (S), 172.5mm (M), 175mm (L/XL)\nBottom bracket\tShimano BB72, 68mm, PressFit\nCassette\tShimano Ultegra, 11-30T, 11 speed\nChain\tShimano HG701, 11 speed\nPedal\tNot included\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tCarbon seat post, 27.2mm, 20mm offset\nHandlebar\tBontrager XXX Aero, carbon, 31.8mm clamp, 75mm reach, 125mm drop\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Pro, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, Ultegra lever, Ultegra caliper\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.36 kg / 18.42 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.6 | 52.4 | 54.3 | 56.2 |\n| B — Seat tube angle | 75.5° | 74.5° | 73.5° | 72.5° |\n| C — Head tube length | 12.0 | 14.0 | 16.0 | 18.0 |\n| D — Head angle | 72.5° | 73.0° | 73.5° | 74.0° |\n| E — Effective top tube | 53.8 | 55.4 | 57.0 | 58.6 |\n| F — Bottom bracket height | 26.5 | 26.5 | 26.5 | 26.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 6.0 |\n| K — Wheelbase | 97.1 | 98.7 | 100.2 | 101.8 |\n| L — Standover | 73.8 | 76.2 | 78.5 | 80.8 |\n| M — Frame reach | 38.8 | 39.5 | 40.2 | 40.9 |\n| N — Frame stack | 52.8 | 54.7 | 56.6 | 58.5 |", + "price": 1599.99, + "tags": [ + "bicycle", + "road bike" + ] + }, + { + "name": "Voltex+ Ultra Lowstep", + "shortDescription": "Voltex+ Ultra Lowstep is a high-performance electric hybrid bike designed for riders who seek speed, comfort, and reliability during their everyday rides. Equipped with a powerful and efficient Voltex Drive Pro motor and a fully-integrated 600Wh battery, this e-bike allows you to cover longer distances on a single charge. The Voltex+ Ultra Lowstep comes with premium components that prioritize comfort and safety, such as a suspension seatpost, wide and stable tires, and integrated lights.", + "description": "## Overview\n\nIt's right for you if...\nYou want an e-bike that provides a boost for faster rides and effortless usage. Durability is crucial, and you need a bike with one of the most powerful and efficient motors.\n\nThe tech you get\nA lightweight Delta Carbon Fiber frame with an ultra-lowstep design, a Voltex Drive Pro (350W, 75Nm) motor capable of maintaining speeds up to 30 mph, an extended range 600Wh battery integrated into the frame, and a Voltex Control Panel. Additionally, it features a 12-speed Shimano drivetrain, hydraulic disc brakes for optimal all-weather stopping power, a suspension seatpost, wide puncture-resistant tires for added stability, ergonomic grips, a kickstand, lights, and a cargo rack.\n\nThe final word\nThis bike offers enhanced enjoyment and ease of use on long commutes, leisure rides, and adventures. With its extended-range battery, powerful Voltex motor, user-friendly controller, and a seatpost that smooths out road vibrations, it guarantees an exceptional riding experience.\n\n## Features\n\nUltra-fast assistance\n\nExperience speeds up to 30 mph with the cutting-edge Voltex Drive Pro motor, allowing you to breeze through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\n- Frame: Delta Carbon Fiber, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Voltex Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: Voltex Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: Voltex E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore XT M8100, 12-speed\n- Rear derailleur: Shimano Deore XT M8100, long cage\n- Crank: Voltex alloy, 170mm length\n- Chainring: FSA, 44T, aluminum with guard\n- Cassette: Shimano Deore XT M8100, 10-51, 12-speed\n- Chain: KMC E12 Turbo\n- Pedal: Voltex Urban pedals\n\nComponents\n- Saddle: Voltex Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar: Voltex alloy, 31.8mm, comfort sweep, 620mm width (XS, S, M), 660mm width (L)\n- Grips: Voltex Satellite Elite, alloy lock-on\n- Stem: Voltex alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length (XS, S), 105mm length (M, L)\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT520 hydraulic disc\n- Brake rotor: Shimano RT56, 6-bolt, 180mm (XS, S, M, L), 160mm (XS, S, M, L)\n\nAccessories\n- Battery: Voltex PowerTube 600Wh\n- Charger: Voltex compact 2A, 100-240V\n- Computer: Voltex Control Panel\n- Motor: Voltex Drive Pro, 75Nm, 30mph\n- Light: Voltex Solo for e-bike, taillight (XS, S, M, L), Voltex MR8, 180 lumen, 60 lux, LED, headlight (XS, S, M, L)\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: Voltex-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender: Voltex wide (XS, S, M, L), Voltex plastic (XS, S, M, L)\n\nWeight\n- Weight: M - 20.50 kg / 45.19 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 330 pounds (150 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 38.0 | 43.0 | 48.0 | 53.0 |\n| B — Seat tube angle | 70.5° | 70.5° | 70.5° | 70.5° |\n| C — Head tube length | 15.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 69.2° | 69.2° | 69.2° | 69.2° |\n| E — Effective top tube | 57.2 | 57.7 | 58.8 | 60.0 |\n| F — Bottom bracket height | 30.3 | 30.3 | 30.3 | 30.3 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.5 | 48.5 | 48.5 | 48.5 |\n| I — Offset | 5.0 | 5.0 | 5.0 | 5.0 |\n| J — Trail | 9.0 | 9.0 | 9.0 | 9.0 |\n| K — Wheelbase | 111.8 | 112.3 | 113.6 | 114.8 |\n| L — Standover | 42.3 | 42.3 | 42.3 | 42.3 |\n| M — Frame reach | 36.0 | 38.0 | 38.0 | 38.0 |\n| N — Frame stack | 62.0 | 62.0 | 63.9 | 65.8 |\n| Stem length | 8.0 | 8.5 | 8.5 | 10.5 |\n\nPlease note that the specifications and features listed above are subject to change and may vary based on different models and versions of the Voltex+ Ultra Lowstep bike.", + "price": 2999.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "SwiftRide Hybrid", + "shortDescription": "SwiftRide Hybrid is a versatile and efficient bike designed for riders who want a smooth and enjoyable ride on various terrains. It incorporates advanced technology and high-quality components to provide a comfortable and reliable cycling experience.", + "description": "## Overview\n\nIt's right for you if...\nYou are looking for a bike that combines the benefits of an electric bike with the versatility of a hybrid. You value durability, speed, and ease of use.\n\nThe tech you get\nThe SwiftRide Hybrid features a lightweight and durable aluminum frame, making it easy to handle and maneuver. It is equipped with a powerful electric motor that offers a speedy assist, helping you reach speeds of up to 25 mph. The bike comes with a removable and fully-integrated 500Wh battery, providing a long-range capacity for extended rides. It also includes a 10-speed Shimano drivetrain, hydraulic disc brakes for precise stopping power, wide puncture-resistant tires for stability, and integrated lights for enhanced visibility.\n\nThe final word\nThe SwiftRide Hybrid is designed for riders who want a bike that can handle daily commutes, recreational rides, and adventures. With its efficient motor, intuitive controls, and comfortable features, it offers an enjoyable and hassle-free riding experience.\n\n## Features\n\nEfficient electric assist\nExperience the thrill of effortless riding with the powerful electric motor that provides a speedy assist, making your everyday rides faster and more enjoyable.\n\n## Specs\n\nFrameset\n- Frame: Lightweight Aluminum, Removable Integrated Battery (RIB), rack & fender mounts, internal routing, 135x5mm QR\n- Fork: SwiftRide Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: SwiftRide Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: SwiftRide E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: SwiftRide City pedals\n\nComponents\n- Saddle: SwiftRide Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - SwiftRide alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - SwiftRide alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: SwiftRide Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 85mm length\n - Size: M, L - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: SwiftRide PowerTube 500Wh\n- Charger: SwiftRide compact 2A, 100-240V\n- Computer: SwiftRide Purion\n- Motor: SwiftRide Performance Line Sport, 65Nm, 25mph\n- Light:\n - Size: XS, S, M, L - SwiftRide SOLO for e-bike, taillight\n - Size: XS, S, M, L - SwiftRide MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: SwiftRide-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SwiftRide wide\n - Size: XS, S, M, L - SwiftRide plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm (4'10\" - 5'1\") | 69 - 73 cm (27\" - 29\") |\n| S | 155 - 165 cm (5'1\" - 5'5\") | 72 - 78 cm (28\" - 31\") |\n| M | 165 - 175 cm (5'5\" - 5'9\") | 77 - 83 cm (30\" - 33\") |\n| L | 175 - 186 cm (5'9\" - 6'1\") | 82 - 88 cm (32\" - 35\") |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 3999.99, + "tags": [ + "bicycle", + "mountain bike", + "professional" + ] + }, + { + "name": "RoadRunner E-Speed Lowstep", + "shortDescription": "RoadRunner E-Speed Lowstep is a high-performance electric hybrid designed for riders seeking speed and excitement on their daily rides. It is equipped with a powerful and reliable ThunderBolt drive unit that offers exceptional acceleration. The bike features a fully-integrated 500Wh battery, allowing riders to cover longer distances on a single charge. With its comfortable and safe components, including a suspension seatpost, wide and stable tires, and integrated lights, the RoadRunner E-Speed Lowstep ensures a smooth and enjoyable ride.", + "description": "## Overview\n\nIt's right for you if...\nYou're looking for an e-bike that provides an extra boost to reach your destination quickly and effortlessly. You prioritize durability and want a bike with one of the fastest motors available.\n\nThe tech you get\nA lightweight and sturdy ThunderBolt aluminum frame with a lowstep geometry. The bike is equipped with a ThunderBolt Performance Sport (250W, 65Nm) drive unit capable of reaching speeds up to 28 mph. It features a long-range 500Wh battery fully integrated into the frame and a ThunderBolt controller. Additionally, the bike has a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe RoadRunner E-Speed Lowstep is designed to provide enjoyment and ease of use on longer commutes, recreational rides, and adventurous journeys. Its long-range battery, fast ThunderBolt motor, intuitive controller, and road-smoothing suspension seatpost make it the perfect choice for riders seeking both comfort and speed.\n\n## Features\n\nSuper speedy assist\n\nThe ThunderBolt Performance Sport drive unit allows you to accelerate up to 28mph, making errands, commutes, and joyrides a breeze.\n\n## Specs\n\nFrameset\n- Frame: ThunderBolt Smooth Aluminum, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: RoadRunner Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: ThunderBolt DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: ThunderBolt DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: ThunderBolt Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: ThunderBolt E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: RoadRunner City pedals\n\nComponents\n- Saddle: RoadRunner Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - RoadRunner alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - RoadRunner alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: RoadRunner Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: ThunderBolt PowerTube 500Wh\n- Charger: ThunderBolt compact 2A, 100-240V\n- Computer: ThunderBolt Purion\n- Motor: ThunderBolt Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - ThunderBolt SOLO for e-bike, taillight\n - Size: XS, S, M, L - ThunderBolt MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - RoadRunner wide\n - Size: XS, S, M, L - RoadRunner plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 4999.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "Hyperdrive Turbo X1", + "shortDescription": "Hyperdrive Turbo X1 is a high-performance electric bike designed for riders seeking an exhilarating experience on their daily rides. It features a powerful and efficient Hyperdrive Sport drive unit and a sleek, integrated 500Wh battery for extended range. This e-bike is equipped with top-of-the-line components prioritizing comfort and safety, including a suspension seatpost, wide and stable tires, and integrated lights.", + "description": "## Overview\n\nIt's right for you if...\nYou crave the thrill of an e-bike that can accelerate rapidly, reaching high speeds effortlessly. You value durability and are looking for a bike that is equipped with one of the fastest motors available.\n\nThe tech you get\nA lightweight Hyper Alloy frame with a lowstep geometry, a Hyperdrive Sport (300W, 70Nm) drive unit capable of maintaining speeds up to 30 mph, a long-range 500Wh battery seamlessly integrated into the frame, and an intuitive Hyper Control controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for enhanced stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThis bike is designed for riders seeking enjoyment and convenience on longer commutes, recreational rides, and thrilling adventures. With its long-range battery, high-speed motor, user-friendly controller, and smooth-riding suspension seatpost, the Hyperdrive Turbo X1 guarantees an exceptional e-biking experience.\n\n## Features\n\nHyperboost Acceleration\nExperience adrenaline-inducing rides with the powerful Hyperdrive Sport drive unit that enables quick acceleration and effortless cruising through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\nFrame\tHyper Alloy, Removable Integrated Battery (RIB), seamless welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\nFork\tHyper Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\nMax compatible fork travel\t50mm\n\nWheels\nHub front\tFormula DC-20, alloy, 6-bolt, 5x100mm QR\nSkewer front\t132x5mm QR, ThruSkew\nHub rear\tFormula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\nSkewer rear\t153x5mm bolt-on\nRim\tHyper Connection, double-wall, 32-hole, 20 mm width, Schrader valve\nTire\tHyper E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\nMax tire size\t700x50mm with or without fenders\n\nDrivetrain\nShifter\tShimano Deore M4100, 10 speed\nRear derailleur\tShimano Deore M5120, long cage\nCrank\tProWheel alloy, 170mm length\nChainring\tFSA, 42T, steel w/guard\nCassette\tShimano Deore M4100, 11-42, 10 speed\nChain\tKMC E10\nPedal\tHyper City pedals\n\nComponents\nSaddle\tHyper Boulevard\nSeatpost\tAlloy, suspension, 31.6mm, 300mm length\n*Handlebar\tSize: XS, S, M\nHyper alloy, 31.8mm, comfort sweep, 620mm width\nSize: L\nHyper alloy, 31.8mm, comfort sweep, 660mm width\nGrips\tHyper Satellite Elite, alloy lock-on\n*Stem\tSize: XS, S\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\nSize: M, L\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\nHeadset\tVP sealed cartridge, 1-1/8'', threaded\nBrake\tShimano MT200 hydraulic disc\n*Brake rotor\tSize: XS, S, M, L\nShimano RT26, 6-bolt,180mm\nSize: XS, S, M, L\nShimano RT26, 6-bolt,160mm\n\nAccessories\nBattery\tHyper PowerTube 500Wh\nCharger\tHyper compact 2A, 100-240V\nComputer\tHyper Control\nMotor\tHyperdrive Sport, 70Nm, 30mph\n*Light\tSize: XS, S, M, L\nSpanninga SOLO for e-bike, taillight\nSize: XS, S, M, L\nHerrmans MR8, 180 lumen, 60 lux, LED, headlight\nKickstand\tAdjustable length rear mount alloy kickstand\nCargo rack\tMIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n*Fender\tSize: XS, S, M, L\nSKS wide\nSize: XS, S, M, L\nSKS plastic\n\nWeight\nWeight\tM - 22.30 kg / 49.17 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 1999.99, + "tags": [ + "bicycle", + "city bike", + "professional" + ] + }, + { + "name": "Horizon+ Evo Lowstep", + "shortDescription": "The Horizon+ Evo Lowstep is a versatile electric hybrid bike designed for riders seeking a thrilling and efficient riding experience on a variety of terrains. With its powerful Bosch Performance Line Sport drive unit and integrated 500Wh battery, this e-bike enables riders to cover long distances with ease. Equipped with features prioritizing comfort and safety, such as a suspension seatpost, stable tires, and integrated lights, the Horizon+ Evo Lowstep is a reliable companion for everyday rides.", + "description": "## Overview\n\nIt's right for you if...\nYou desire the convenience and speed of an e-bike to enhance your riding, and you want an intuitive and durable bicycle. You prioritize having one of the fastest motors developed by Bosch.\n\nThe tech you get\nA lightweight Alpha Smooth Aluminum frame with a lowstep geometry, a Bosch Performance Line Sport (250W, 65Nm) drive unit capable of sustaining speeds up to 28 mph, a fully encased 500Wh battery integrated into the frame, and a Bosch Purion controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for improved stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe Horizon+ Evo Lowstep offers an enjoyable and user-friendly riding experience for longer commutes, recreational rides, and adventures. It boasts an extended range battery, a high-performance Bosch motor, an intuitive controller, and a suspension seatpost for a smooth ride on various road surfaces.\n\n## Features\n\nSuper speedy assist\nExperience effortless cruising through errands, commutes, and joyrides with the new Bosch Performance Sport drive unit, allowing acceleration of up to 28 mph.\n\n## Specs\n\nFrameset\n- Frame: Alpha Platinum Aluminum, Removable Integrated Battery (RIB), smooth welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Horizon Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Front Hub: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Front Skewer: 132x5mm QR, ThruSkew\n- Rear Hub: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Rear Skewer: 153x5mm bolt-on\n- Rim: Bontrager Connection, double-wall, 32-hole, 20mm width, Schrader valve\n- Tire: Bontrager E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10-speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10-speed\n- Chain: KMC E10\n- Pedal: Bontrager City pedals\n\nComponents\n- Saddle: Bontrager Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - Bontrager alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - Bontrager alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: Bontrager Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8\", threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: Bosch PowerTube 500Wh\n- Charger: Bosch compact 2A, 100-240V\n- Computer: Bosch Purion\n- Motor: Bosch Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - Spanninga SOLO for e-bike, taillight\n - Size: XS, S, M, L - Herrmans MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SKS wide\n - Size: XS, S, M, L - SKS plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 4499.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "FastRider X1", + "shortDescription": "FastRider X1 is a high-performance e-bike designed for riders seeking speed and long-distance capabilities. Equipped with a powerful motor and a high-capacity battery, the FastRider X1 is perfect for daily commuters and e-bike enthusiasts. It boasts a sleek and functional design, making it a great alternative to car transportation. The bike also features a smartphone controller for easy navigation and entertainment options.", + "description": "## Overview\nIt's right for you if...\nYou're looking for an e-bike that offers both speed and endurance. The FastRider X1 comes with a high-performance motor and a long-lasting battery, making it ideal for long-distance rides.\n\nThe tech you get\nThe FastRider X1 features a state-of-the-art motor and a spacious battery, ensuring a fast and efficient ride.\n\nThe final word\nWith the powerful motor and long-range battery, the FastRider X1 allows you to cover more distance at higher speeds.\n\n## Features\nConnect Your Ride with the FastRider App\nDownload the FastRider app and transform your smartphone into an on-board computer. Easily dock and charge your phone with the smartphone controller, and use the thumb pad on your handlebar to make calls, listen to music, get turn-by-turn directions, and more. The app also allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nGoodbye, Car. Hello, Extended Range!\nWith the option to add the Range Boost feature, you can attach a second long-range battery to your FastRider X1, doubling the distance and time between charges. This enhancement allows you to ride longer, commute farther, and take on more adventurous routes.\n\nWhat is the range?\nTo estimate the distance you can travel on a single charge, use our range calculator tool. It automatically fills in the variables for this specific bike model and assumes an average rider, but you can adjust the settings to get the most accurate estimate for your needs.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: FastRider rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: FastRider sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: FastRider Switch thru axle, removable lever\n- Rear Hub: FastRider alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: FastRider MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: FastRider E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - FastRider alloy, 170mm length / Size: L, XL - FastRider alloy, 175mm length\n- Chainring: FastRider 46T narrow/wide alloy, w/alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10 / Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - FastRider City pedals / Size: M, L, XL - Wellgo C157, boron axle, plastic body / Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: FastRider Commuter Comp\n- Seatpost: FastRider Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - FastRider alloy, 31.8mm, 15mm rise, 600mm width / Size: L, XL - FastRider alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: FastRider Satellite Elite, alloy lock-on\n- Stem: Size: M - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length / Size: L - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length / Size: XL - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom / Size: M, L, XL - FSA Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: FastRider PowerTube 625Wh\n- Charger: FastRider standard 4A, 100-240V\n- Motor: FastRider Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - FastRider taillight, 50 lumens / Size: M, L, XL - FastRider headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy / Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: FastRider integrated rear rack, aluminum\n- Fender: FastRider custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n\nWeight limit\n- This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", + "price": 5499.99, + "tags": [ + "bicycle", + "mountain bike", + "professional" + ] + }, + { + "name": "SonicRide 8S", + "shortDescription": "SonicRide 8S is a high-performance e-bike designed for riders who crave speed and long-distance capabilities. The advanced SonicDrive motor provides powerful assistance up to 28 mph, combined with a durable and long-lasting battery for extended rides. With its sleek design and thoughtful features, the SonicRide 8S is perfect for those who prefer the freedom of riding a bike over driving a car. Plus, it comes equipped with a smartphone controller for easy navigation, music, and more.", + "description": "## Overview\nIt's right for you if...\nYou want a fast and efficient e-bike that can take you long distances. The SonicRide 8S features a hydroformed aluminum frame with a concealed 625Wh battery, a high-powered SonicDrive motor, and a Smartphone Controller. It also includes essential accessories such as lights, fenders, and a rear rack.\n\nThe tech you get\nThe SonicRide 8S is equipped with the fastest SonicDrive motor, ensuring exhilarating rides at high speeds. The long-range battery is perfect for commuters and riders looking to explore new horizons.\n\nThe final word\nWith the SonicDrive motor and long-lasting battery, you can enjoy extended rides at higher speeds.\n\n## Features\n\nConnect Your Ride with SonicRide App\nDownload the SonicRide app and transform your phone into an onboard computer. Simply attach it to the Smartphone Controller for docking and charging. Use the thumb pad on your handlebar to control calls, music, directions, and more. The Bluetooth® wireless technology allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nSay Goodbye to Limited Range with Range Boost!\nExperience the convenience of Range Boost, an additional long-range 500Wh battery that seamlessly attaches to your bike's down tube. This upgrade allows you to double your distance and time between charges, enabling longer commutes and more adventurous rides. Range Boost is compatible with select SonicRide electric bike models.\n\nWhat is the range?\nFor an accurate estimate of how far you can ride on a single charge, use SonicRide's range calculator. We have pre-filled the variables for this specific bike model and the average rider, but you can adjust them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: SonicRide rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: SonicRide sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: SonicRide Switch thru axle, removable lever\n- Rear Hub: SonicRide alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SonicRide MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: SonicRide E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - SonicRide alloy, 170mm length; Size: L, XL - SonicRide alloy, 175mm length\n- Chainring: SonicRide 46T narrow/wide alloy, with alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10; Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - SonicRide City pedals; Size: M, L, XL - Wellgo C157, boron axle, plastic body; Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: SonicRide Commuter Comp\n- Seatpost: SonicRide Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - SonicRide alloy, 31.8mm, 15mm rise, 600mm width; Size: L, XL - SonicRide alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: SonicRide Satellite Elite, alloy lock-on\n- Stem: Size: M - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length; Size: L - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length; Size: XL - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - SonicRide IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom; Size: M, L, XL - SonicRide Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: SonicRide PowerTube 625Wh\n- Charger: SonicRide standard 4A, 100-240V\n- Motor: SonicRide Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - SonicRide Lync taillight, 50 lumens; Size: M, L, XL - SonicRide Lync headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy; Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: SonicRide integrated rear rack, aluminum\n- Fender: SonicRide custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm / 5'5\" - 5'9\" | 77 - 83 cm / 30\" - 33\" |\n| L | 175 - 186 cm / 5'9\" - 6'1\" | 82 - 88 cm / 32\" - 35\" |\n| XL | 186 - 197 cm / 6'1\" - 6'6\" | 87 - 93 cm / 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |", + "price": 5999.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "SwiftVolt Pro", + "shortDescription": "SwiftVolt Pro is a high-performance e-bike designed for riders seeking a thrilling and fast riding experience. Equipped with a powerful SwiftDrive motor that provides assistance up to 30 mph and a long-lasting battery, this bike is perfect for long-distance commuting and passionate e-bike enthusiasts. The sleek and innovative design features cater specifically to individuals who prioritize cycling over driving. Additionally, the bike is seamlessly integrated with your smartphone, allowing you to use it for navigation, music, and more.", + "description": "## Overview\nThis bike is ideal for you if:\n- You desire a sleek and modern hydroformed aluminum frame that houses a 700Wh battery.\n- You want to maintain high speeds of up to 30 mph with the assistance of the SwiftDrive motor.\n- You appreciate the convenience of using your smartphone as a controller, which can be docked and charged on the handlebar.\n\n## Features\n\nConnect with SwiftSync App\nBy downloading the SwiftSync app, your smartphone becomes an interactive on-board computer. Attach it to the handlebar-mounted controller for easy access and charging. With the thumb pad, you can make calls, listen to music, receive turn-by-turn directions, and connect with fitness and health apps to track your routes and ride data via Bluetooth® wireless technology.\n\nEnhanced Range with BoostMax\nBoostMax offers the capability to attach a second 700Wh Swift battery to the downtube of your bike, effectively doubling the distance and time between charges. This allows for extended rides, longer commutes, and more significant adventures. BoostMax is compatible with select Swift electric bike models.\n\nRange Estimation\nFor an estimate of how far you can ride on a single charge, consult the Swift range calculator. The variables are automatically populated based on this bike model and the average rider, but you can modify them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: Lightweight hydroformed alloy, Removable Integrated Battery, BoostMax-compatible, internal cable routing, post-mount disc, 135x5 mm QR\n- Fork: SwiftVolt rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: Swift sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: Swift Switch thru-axle, removable lever\n- Rear Hub: Swift alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SwiftRim, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: 14g stainless steel, black\n- Tire: Swift E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: Swift alloy, 170mm length\n- Chainring: Swift 46T narrow/wide alloy, w/alloy guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: Swift City pedals\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: Swift Commuter Comp\n- Seatpost: Swift Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Swift alloy, 31.8mm, 15mm rise, 600mm width (M), 660mm width (L, XL)\n- Grips: Swift Satellite Elite, alloy lock-on\n- Stem: Swift alloy, 31.8mm, Blendr compatible, 7 degree, 70mm length (M), 90mm length (L), 100mm length (XL)\n- Headset: FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brakes: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake Rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max 180mm front & rear\n\nAccessories\n- Battery: Swift PowerTube 700Wh\n- Charger: Swift standard 4A, 100-240V\n- Motor: SwiftDrive, 90 Nm, 30 mph / 48 kph\n- Light: Swift Lync taillight, 50 lumens (M, L, XL), Swift Lync headlight, 500 lumens (M, L, XL)\n- Kickstand: Rear mount, alloy (M, L, XL), Adjustable length alloy kickstand (M, L, XL)\n- Cargo rack: SwiftVolt integrated rear rack, aluminum\n- Fender: Swift custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:-------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", + "price": 2499.99, + "tags": [ + "bicycle", + "city bike", + "professional" + ] + }, + { + "name": "AgileEon 9X", + "shortDescription": "AgileEon 9X is a high-performance e-bike designed for riders seeking speed and endurance. Equipped with a robust motor and an extended battery life, this bike is perfect for long-distance commuters and avid e-bike enthusiasts. It boasts innovative features tailored for individuals who prioritize cycling over driving. Additionally, the bike integrates seamlessly with your smartphone, allowing you to access navigation, music, and more.", + "description": "## Overview\nIt's right for you if...\nYou crave speed and want to cover long distances efficiently. The AgileEon 9X features a sleek hydroformed aluminum frame that houses a powerful motor, along with a large-capacity battery for extended rides. It comes equipped with a 10-speed drivetrain, front and rear lighting, fenders, and a rear rack.\n\nThe tech you get\nDesigned for those constantly on the move, this bike includes a state-of-the-art motor and a high-capacity battery, making it an excellent choice for lengthy commutes.\n\nThe final word\nWith the AgileEon 9X, you can push your boundaries and explore new horizons thanks to its powerful motor and long-lasting battery.\n\n## Features\n\nConnect Your Ride with RideMate App\nMake use of the RideMate app to transform your smartphone into an onboard computer. Simply attach it to the RideMate controller to dock and charge, then utilize the thumb pad on your handlebar to make calls, listen to music, receive turn-by-turn directions, and more. The bike also supports Bluetooth® wireless technology, enabling seamless connectivity with fitness and health apps for route syncing and ride data.\n\nGoodbye, car. Hello, Extended Range!\nEnhance your riding experience with the Extended Range option, which allows for the attachment of an additional high-capacity 500Wh battery to your bike's downtube. This doubles the distance and time between charges, enabling longer rides, extended commutes, and more significant adventures. The Extended Range feature is compatible with select AgileEon electric bike models.\n\nWhat is the range?\nTo determine how far you can ride on a single charge, you can utilize the range calculator provided by AgileEon. We have pre-filled the variables for this specific model and an average rider, but adjustments can be made for a more accurate estimation.\n\n## Specifications\nFrameset\nFrame: High-performance hydroformed alloy, Removable Integrated Battery, Extended Range-compatible, internal cable routing, Motor Armor, post-mount disc, 135x5 mm QR\nFork: AgileEon rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\nMax compatible fork travel: 63mm\n\nWheels\nFront Hub: AgileEon sealed bearing, 32-hole 15mm alloy thru-axle\nFront Skewer: AgileEon Switch thru-axle, removable lever\nRear Hub: AgileEon alloy, sealed bearing, 6-bolt, 135x5mm QR\nRear Skewer: 148x5mm bolt-on\nRim: AgileEon MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\nSpokes:\n- Size: M, L, XL: 14g stainless steel, black\nTire: AgileEon E6 Hard-Case Lite, reflective strip, 27.5x2.40''\nMax tire size: 27.5x2.40\"\n\nDrivetrain\nShifter: Shimano Deore M4100, 10-speed\nRear derailleur:\n- Size: M, L, XL: Shimano Deore M5120, long cage\nCrank:\n- Size: M: AgileEon alloy, 170mm length\n- Size: L, XL: AgileEon alloy, 175mm length\nChainring: AgileEon 46T narrow/wide alloy, with alloy guard\nCassette:\n- Size: M, L, XL: Shimano Deore M4100, 11-42, 10-speed\nChain:\n- Size: M, L, XL: KMC E10\nPedal:\n- Size: M, L, XL: AgileEon City pedals\nMax chainring size: 1x: 48T\n\nComponents\nSaddle: AgileEon Commuter Comp\nSeatpost: AgileEon Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\nHandlebar:\n- Size: M: AgileEon alloy, 31.8mm, 15mm rise, 600mm width\n- Size: L, XL: AgileEon alloy, 31.8mm, 15mm rise, 660mm width\nGrips: AgileEon Satellite Elite, alloy lock-on\nStem:\n- Size: M: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length\n- Size: L: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length\n- Size: XL: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\nHeadset:\n- Size: M, L, XL: AgileEon IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\nBrake rotor: Shimano RT56, 6-bolt, 180mm\nRotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\nBattery: AgileEon PowerTube 625Wh\nCharger: AgileEon standard 4A, 100-240V\nMotor: AgileEon Performance Speed, 85 Nm, 28 mph / 45 kph\nLight:\n- Size: M, L, XL: AgileEon taillight, 50 lumens\n- Size: M, L, XL: AgileEon headlight, 500 lumens\nKickstand:\n- Size: M, L, XL: Rear mount, alloy\n- Size: M, L, XL: Adjustable length alloy kickstand\nCargo rack: AgileEon integrated rear rack, aluminum\nFender: AgileEon custom aluminum\n\nWeight\nWeight: M - 25.54 kg / 56.3 lbs\nWeight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", + "price": 3499.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "Stealth R1X Pro", + "shortDescription": "Stealth R1X Pro is a high-performance carbon road bike designed for riders who crave speed and exceptional handling. With its aerodynamic tube shaping, disc brakes, and lightweight carbon wheels, the Stealth R1X Pro offers unparalleled performance for competitive road cycling.", + "description": "## Overview\nIt's right for you if...\nYou're a competitive cyclist looking for a road bike that offers superior performance in terms of speed, handling, and aerodynamics. You want a complete package that includes lightweight carbon wheels, without the need for future upgrades.\n\nThe tech you get\nThe Stealth R1X Pro features a lightweight and aerodynamic carbon frame, an advanced carbon fork, high-performance Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes. The bike also comes equipped with cutting-edge Bontrager Aeolus Elite 35 carbon wheels.\n\nThe final word\nThe Stealth R1X Pro stands out with its combination of a fast and aerodynamic frame, high-end drivetrain, and top-of-the-line carbon wheels. Whether you're racing on local roads, participating in pro stage races, or engaging in hill climbing competitions, this bike is a formidable choice that delivers an exceptional riding experience.\n\n## Features\nSleek and aerodynamic design\nThe Stealth R1X Pro's aero tube shapes maximize speed and performance, making it faster on climbs and flats alike. The bike also features a streamlined Aeolus RSL bar/stem for improved front-end aerodynamics.\n\nDesigned for all riders\nThe Stealth R1X Pro is designed to provide an outstanding fit for riders of all genders, body types, riding styles, and abilities. It comes equipped with size-specific components to ensure a comfortable and efficient riding position for competitive riders.\n\n## Specifications\nFrameset\n- Frame: Ultralight carbon frame constructed with high-performance 500 Series ADV Carbon. It features Ride Tuned performance tube optimization, a tapered head tube, internal routing, DuoTrap S compatibility, flat mount disc brake mounts, and a 142x12mm thru axle.\n- Fork: Full carbon fork (Émonda SL) with a tapered carbon steerer, internal brake routing, flat mount disc brake mounts, and a 12x100mm thru axle.\n- Frame fit: H1.5 Race geometry.\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, and a 100x12mm thru axle.\n- Rear wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, Shimano 11/12-speed freehub, and a 142x12mm thru axle.\n- Front skewer: Bontrager Switch thru axle with a removable lever.\n- Rear skewer: Bontrager Switch thru axle with a removable lever.\n- Tire: Bontrager R2 Hard-Case Lite with an aramid bead, 60 tpi, and a size of 700x25c.\n- Maximum tire size: 28mm.\n\nDrivetrain\n- Shifter:\n - Size 47, 50, 52: Shimano Ultegra R8025 with short-reach levers, 11-speed.\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed.\n- Front derailleur: Shimano Ultegra R8000, braze-on.\n- Rear derailleur: Shimano Ultegra R8000, short cage, with a maximum cog size of 30T.\n- Crank:\n - Size 47: Shimano Ultegra R8000 with 52/36 chainrings and a 165mm length.\n - Size 50, 52: Shimano Ultegra R8000 with 52/36 chainrings and a 170mm length.\n - Size 54, 56, 58: Shimano Ultegra R8000 with 52/36 chainrings and a 172.5mm length.\n - Size 60, 62: Shimano Ultegra R8000 with 52/36 chainrings and a 175mm length.\n- Bottom bracket: Praxis T47 threaded bottom bracket with internal bearings.\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed.\n- Chain: Shimano Ultegra HG701, 11-speed.\n- Maximum chainring size: 1x - 50T, 2x - 53/39.\n\nComponents\n- Saddle: Bontrager Aeolus Comp with steel rails and a width of 145mm.\n- Seatpost:\n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap with a 20mm offset and a short length.\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap with a 20mm offset and a tall length.\n- Handlebar:\n - Size 47, 50: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 38cm.\n - Size 52: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 40cm.\n - Size 54, 56, 58: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 42cm.\n - Size 60, 62: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 44cm.\n- Handlebar tape: Bontrager Supertack Perf tape.\n- Stem:\n - Size 47: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 70mm.\n - Size 50: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 80mm.\n - Size 52, 54: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 90mm.\n - Size 56: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 100mm.\n - Size 58, 60, 62: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 110mm.\n- Brake: Shimano Ultegra hydraulic disc brakes with flat mount calipers.\n- Brake rotor: Shimano RT800 with centerlock mounting, 160mm diameter.\n\nWeight\n- Weight: 8.03 kg (17.71 lbs) for the 56cm frame.\n- Weight limit: The bike has a maximum total weight limit (combined weight of the bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\nPlease refer to the table below for the corresponding Stealth R1X Pro frame sizes, recommended rider height range, and inseam measurements:\n\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:--------------:|\n| 47 | 152 - 158 cm (5'0\") | 71 - 75 cm |\n| 50 | 158 - 163 cm (5'2\") | 74 - 77 cm |\n| 52 | 163 - 168 cm (5'4\") | 76 - 79 cm |\n| 54 | 168 - 174 cm (5'6\") | 78 - 82 cm |\n| 56 | 174 - 180 cm (5'9\") | 81 - 85 cm |\n| 58 | 180 - 185 cm (5'11\") | 84 - 87 cm |\n| 60 | 185 - 190 cm (6'1\") | 86 - 90 cm |\n| 62 | 190 - 195 cm (6'3\") | 89 - 92 cm |\n\n## Geometry\nThe table below provides the geometry measurements for each frame size of the Stealth R1X Pro:\n\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|-------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", + "price": 2999.99, + "tags": [ + "bicycle", + "mountain bike", + "professional" + ] + }, + { + "name": "Avant SLR 6 Disc Pro", + "shortDescription": "Avant SLR 6 Disc Pro is a high-performance carbon road bike designed for riders who prioritize speed and handling. With its aero tube shaping, disc brakes, and lightweight carbon wheels, it offers the perfect balance of speed and control.", + "description": "## Overview\nIt's right for you if...\nYou're a rider who values exceptional performance on fast group rides and races, and you want a complete package that includes lightweight carbon wheels. The Avant SLR 6 Disc Pro is designed to provide the speed and aerodynamics you need to excel on any road.\n\nThe tech you get\nThe Avant SLR 6 Disc Pro features a lightweight 500 Series ADV Carbon frame and fork, Bontrager Aeolus Elite 35 carbon wheels, a full Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes.\n\nThe final word\nThe standout feature of this bike is the combination of its aero frame, high-performance drivetrain, and top-quality carbon wheels. Whether you're racing, tackling challenging climbs, or participating in professional stage races, the Avant SLR 6 Disc Pro is a worthy choice that will enhance your performance.\n\n## Features\nAll-new aero design\nThe Avant SLR 6 Disc Pro features innovative aero tube shapes that provide an advantage in all riding conditions, whether it's climbing or riding on flat roads. Additionally, it is equipped with a sleek new Aeolus RSL bar/stem that enhances front-end aero performance.\n\nAwesome bikes for everyone\nThe Avant SLR 6 Disc Pro is designed with the belief that every rider, regardless of gender, body type, riding style, or ability, deserves a great bike. It is equipped with size-specific components that ensure a perfect fit for competitive riders of all genders.\n\n## Specifications\nFrameset\n- Frame: Ultralight 500 Series ADV Carbon, Ride Tuned performance tube optimization, tapered head tube, internal routing, DuoTrap S compatible, flat mount disc, 142x12mm thru axle\n- Fork: Avant SL full carbon, tapered carbon steerer, internal brake routing, flat mount disc, 12x100mm thru axle\n- Frame fit: H1.5 Race\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x12mm thru axle\n- Rear wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11/12-speed freehub, 142x12mm thru axle\n- Front skewer: Bontrager Switch thru axle, removable lever\n- Rear skewer: Bontrager Switch thru axle, removable lever\n- Tire: Bontrager R2 Hard-Case Lite, aramid bead, 60 tpi, 700x25c\n- Max tire size: 28mm\n\nDrivetrain\n- Shifter: \n - Size 47, 50, 52: Shimano Ultegra R8025, short-reach lever, 11-speed\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed\n- Front derailleur: Shimano Ultegra R8000, braze-on\n- Rear derailleur: Shimano Ultegra R8000, short cage, 30T max cog\n- Crank: \n - Size 47: Shimano Ultegra R8000, 52/36, 165mm length\n - Size 50, 52: Shimano Ultegra R8000, 52/36, 170mm length\n - Size 54, 56, 58: Shimano Ultegra R8000, 52/36, 172.5mm length\n - Size 60, 62: Shimano Ultegra R8000, 52/36, 175mm length\n- Bottom bracket: Praxis, T47 threaded, internal bearing\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed\n- Chain: Shimano Ultegra HG701, 11-speed\n- Max chainring size: 1x: 50T, 2x: 53/39\n\nComponents\n- Saddle: Bontrager Aeolus Comp, steel rails, 145mm width\n- Seatpost: \n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap, 20mm offset, short length\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap, 20mm offset, tall length\n- Handlebar: \n - Size 47, 50: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 38cm width\n - Size 52: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 40cm width\n - Size 54, 56, 58: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 42cm width\n - Size 60, 62: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 44cm width\n- Handlebar tape: Bontrager Supertack Perf tape\n- Stem: \n - Size 47: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 70mm length\n - Size 50: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 80mm length\n - Size 52, 54: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 90mm length\n - Size 56: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 100mm length\n - Size 58, 60, 62: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 110mm length\n- Brake: Shimano Ultegra hydraulic disc, flat mount\n- Brake rotor: Shimano RT800, centerlock, 160mm\n\nWeight\n- Weight: 56 - 8.03 kg / 17.71 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 47 | 152 - 158 cm 5'0\" - 5'2\" | 71 - 75 cm 28\" - 30\" |\n| 50 | 158 - 163 cm 5'2\" - 5'4\" | 74 - 77 cm 29\" - 30\" |\n| 52 | 163 - 168 cm 5'4\" - 5'6\" | 76 - 79 cm 30\" - 31\" |\n| 54 | 168 - 174 cm 5'6\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| 56 | 174 - 180 cm 5'9\" - 5'11\" | 81 - 85 cm 32\" - 33\" |\n| 58 | 180 - 185 cm 5'11\" - 6'1\" | 84 - 87 cm 33\" - 34\" |\n| 60 | 185 - 190 cm 6'1\" - 6'3\" | 86 - 90 cm 34\" - 35\" |\n| 62 | 190 - 195 cm 6'3\" - 6'5\" | 89 - 92 cm 35\" - 36\" |\n\n## Geometry\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (w/short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (w/short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (w/tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (w/tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", + "price": 999.99, + "tags": [ + "bicycle", + "city bike", + "professional" + ] + } +] diff --git a/spring-ai-core/src/test/resources/logback.xml b/spring-ai-core/src/test/resources/logback.xml index 7030fdba805..d8a39b52cdb 100644 --- a/spring-ai-core/src/test/resources/logback.xml +++ b/spring-ai-core/src/test/resources/logback.xml @@ -1,16 +1,32 @@ + + - - - %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n - - + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n + + - - - - - - + + + + + + - \ No newline at end of file + diff --git a/spring-ai-docs/pom.xml b/spring-ai-docs/pom.xml index c82441a43d3..69f8952b3cc 100644 --- a/spring-ai-docs/pom.xml +++ b/spring-ai-docs/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-docs/src/assembly/javadocs.xml b/spring-ai-docs/src/assembly/javadocs.xml index bde989c221c..709caaddc4f 100644 --- a/spring-ai-docs/src/assembly/javadocs.xml +++ b/spring-ai-docs/src/assembly/javadocs.xml @@ -1,3 +1,19 @@ + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg b/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg index 36f90f81868..256b5924f16 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg +++ b/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg @@ -1,22 +1,41 @@ - - + + + + cancel Created with Sketch. - - - - + + + + + + - - + + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg index 98fb78a05a2..ab80895aac8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg +++ b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg @@ -1,21 +1,37 @@ + + + + + version="1.1" + id="svg1" + width="1022" + height="239.33333" + viewBox="0 0 1022 239.33333" + sodipodi:docname="spring_ai_logo copy.svg" + inkscape:version="1.3.2 (091e20e, 2023-11-25)" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns="http://www.w3.org/2000/svg" +> + + + \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc index 7069b6cedb2..6e686b2168c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc @@ -18,7 +18,7 @@ var chatClient = ChatClient.builder(chatModel) ) .build(); -String response = chatClient.prompt() +String response = this.chatClient.prompt() // Set advisor parameters at runtime .advisors(advisor -> advisor.param("chat_memory_conversation_id", "678") .param("chat_memory_response_size", 100)) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc index 6efd467649f..e41e8eb538b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc @@ -97,7 +97,7 @@ The `RateLimit` instance can be acquired from the `GenerationMetadata`, like so: ---- RateLimit rateLimit = generationMetadata.getRateLimit(); -Long tokensRemaining = rateLimit.getTokensRemaining(); +Long tokensRemaining = this.rateLimit.getTokensRemaining(); // do something interesting with the RateLimit metadata ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc index 2ed3840e85f..9022ce19d13 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc @@ -95,8 +95,8 @@ OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() .withSpeed(1.0f) .build(); -SpeechPrompt speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", speechOptions); -SpeechResponse response = openAiAudioSpeechModel.call(speechPrompt); +SpeechPrompt speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", this.speechOptions); +SpeechResponse response = openAiAudioSpeechModel.call(this.speechPrompt); ---- == Manual Configuration @@ -128,7 +128,7 @@ Next, create an `OpenAiAudioSpeechModel`: ---- var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); -var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(openAiAudioApi); +var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(this.openAiAudioApi); var speechOptions = OpenAiAudioSpeechOptions.builder() .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) @@ -136,13 +136,13 @@ var speechOptions = OpenAiAudioSpeechOptions.builder() .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) .build(); -var speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", speechOptions); -SpeechResponse response = openAiAudioSpeechModel.call(speechPrompt); +var speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", this.speechOptions); +SpeechResponse response = this.openAiAudioSpeechModel.call(this.speechPrompt); // Accessing metadata (rate limit info) -OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata(); +OpenAiAudioSpeechResponseMetadata metadata = this.response.getMetadata(); -byte[] responseAsBytes = response.getResult().getOutput(); +byte[] responseAsBytes = this.response.getResult().getOutput(); ---- == Streaming Real-time Audio @@ -153,7 +153,7 @@ The Speech API provides support for real-time audio streaming using chunk transf ---- var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); -var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(openAiAudioApi); +var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(this.openAiAudioApi); OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) @@ -162,9 +162,9 @@ OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) .build(); -SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); +SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", this.speechOptions); -Flux responseStream = openAiAudioSpeechModel.stream(speechPrompt); +Flux responseStream = this.openAiAudioSpeechModel.stream(this.speechPrompt); ---- == Example Code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc index a0cc7d4d1c2..40b283e8537 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc @@ -66,10 +66,10 @@ AzureOpenAiAudioTranscriptionOptions transcriptionOptions = AzureOpenAiAudioTran .withLanguage("en") .withPrompt("Ask not this, but ask that") .withTemperature(0f) - .withResponseFormat(responseFormat) + .withResponseFormat(this.responseFormat) .build(); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = azureOpenAiTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = azureOpenAiTranscriptionModel.call(this.transcriptionRequest); ---- == Manual Configuration @@ -104,7 +104,7 @@ var openAIClient = new OpenAIClientBuilder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .buildClient(); -var azureOpenAiAudioTranscriptionModel = new AzureOpenAiAudioTranscriptionModel(openAIClient, null); +var azureOpenAiAudioTranscriptionModel = new AzureOpenAiAudioTranscriptionModel(this.openAIClient, null); var transcriptionOptions = AzureOpenAiAudioTranscriptionOptions.builder() .withResponseFormat(TranscriptResponseFormat.TEXT) @@ -113,6 +113,6 @@ var transcriptionOptions = AzureOpenAiAudioTranscriptionOptions.builder() var audioFile = new FileSystemResource("/path/to/your/resource/speech/jfk.flac"); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = azureOpenAiAudioTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = this.azureOpenAiAudioTranscriptionModel.call(this.transcriptionRequest); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc index 5bfde4f531c..da3119376df 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc @@ -94,10 +94,10 @@ OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionO .withLanguage("en") .withPrompt("Ask not this, but ask that") .withTemperature(0f) - .withResponseFormat(responseFormat) + .withResponseFormat(this.responseFormat) .build(); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = openAiTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = openAiTranscriptionModel.call(this.transcriptionRequest); ---- == Manual Configuration @@ -129,7 +129,7 @@ Next, create a `OpenAiAudioTranscriptionModel` ---- var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); -var openAiAudioTranscriptionModel = new OpenAiAudioTranscriptionModel(openAiAudioApi); +var openAiAudioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi); var transcriptionOptions = OpenAiAudioTranscriptionOptions.builder() .withResponseFormat(TranscriptResponseFormat.TEXT) @@ -138,8 +138,8 @@ var transcriptionOptions = OpenAiAudioTranscriptionOptions.builder() var audioFile = new FileSystemResource("/path/to/your/resource/speech/jfk.flac"); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = openAiTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = openAiTranscriptionModel.call(this.transcriptionRequest); ---- == Example Code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index 7e8897a6e32..09b24e8767e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -161,9 +161,9 @@ Below is a simple code example extracted from https://github.com/spring-projects byte[] imageData = new ClassPathResource("/multimodal.test.png").getContentAsByteArray(); var userMessage = new UserMessage("Explain what do you see on this picture?", - List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); + List.of(new Media(MimeTypeUtils.IMAGE_PNG, this.imageData))); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); +ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage))); logger.info(response.getResult().getOutput().getContent()); ---- @@ -219,13 +219,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -261,18 +261,18 @@ Next, create a `AnthropicChatModel` and use it for text generations: ---- var anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); -var chatModel = new AnthropicChatModel(anthropicApi, +var chatModel = new AnthropicChatModel(this.anthropicApi, AnthropicChatOptions.builder() .withModel("claude-3-opus-20240229") .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -300,14 +300,14 @@ AnthropicMessage chatCompletionMessage = new AnthropicMessage( List.of(new ContentBlock("Tell me a Joke?")), Role.USER); // Sync request -ResponseEntity response = anthropicApi +ResponseEntity response = this.anthropicApi .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), - List.of(chatCompletionMessage), null, 100, 0.8, false)); + List.of(this.chatCompletionMessage), null, 100, 0.8, false)); // Streaming request -Flux response = anthropicApi +Flux response = this.anthropicApi .chatCompletionStream(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), - List.of(chatCompletionMessage), null, 100, 0.8, true)); + List.of(this.chatCompletionMessage), null, 100, 0.8, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java[AnthropicApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index af9dda9494a..fa3ffe5e6cb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -199,7 +199,7 @@ Below is a code example excerpted from link:https://github.com/spring-projects/s URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); String response = ChatClient.create(chatModel).prompt() .options(AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").build()) - .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) + .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, this.url)) .call() .content(); ---- @@ -231,7 +231,7 @@ String response = ChatClient.create(chatModel).prompt() .options(AzureOpenAiChatOptions.builder() .withDeploymentName("gpt-4o").build()) .user(u -> u.text("Explain what do you see on this picture?") - .media(MimeTypeUtils.IMAGE_PNG, resource)) + .media(MimeTypeUtils.IMAGE_PNG, this.resource)) .call() .content(); ---- @@ -270,13 +270,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -322,13 +322,13 @@ var openAIChatOptions = AzureOpenAiChatOptions.builder() .withMaxTokens(200) .build(); -var chatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions); +var chatModel = new AzureOpenAiChatModel(this.openAIClient, this.openAIChatOptions); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc index c6bd2bce3ac..5e88b0a24e4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc @@ -155,13 +155,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -202,7 +202,7 @@ AnthropicChatBedrockApi anthropicApi = new AnthropicChatBedrockApi( new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(anthropicApi, +BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(this.anthropicApi, AnthropicChatOptions.builder() .withTemperature(0.6) .withTopK(10) @@ -211,11 +211,11 @@ BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(anthropicApi .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -244,11 +244,11 @@ AnthropicChatRequest request = AnthropicChatRequest .build(); // Sync request -AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); +AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(this.request); // Streaming request -Flux responseStream = anthropicChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.anthropicChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java[AnthropicChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc index b62a580cf31..848a532b4bf 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc @@ -133,9 +133,9 @@ Below is a simple code example extracted from https://github.com/spring-projects byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray(); var userMessage = new UserMessage("Explain what do you see o this picture?", - List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); + List.of(new Media(MimeTypeUtils.IMAGE_PNG, this.imageData))); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage))); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket"); ---- @@ -196,13 +196,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -243,7 +243,7 @@ Anthropic3ChatBedrockApi anthropicApi = new Anthropic3ChatBedrockApi( new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(anthropicApi, +BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(this.anthropicApi, AnthropicChatOptions.builder() .withTemperature(0.6) .withTopK(10) @@ -252,11 +252,11 @@ BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(anthropicA .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -281,11 +281,11 @@ AnthropicChatRequest request = AnthropicChatRequest .build(); // Sync request -AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); +AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(this.request); // Streaming request -Flux responseStream = anthropicChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.anthropicChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[Anthropic3ChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc index b7eac40d52c..b3df6b5d737 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc @@ -147,13 +147,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -193,7 +193,7 @@ CohereChatBedrockApi api = new CohereChatBedrockApi(CohereChatModel.COHERE_COMMA new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockCohereChatModel chatModel = new BedrockCohereChatModel(api, +BedrockCohereChatModel chatModel = new BedrockCohereChatModel(this.api, BedrockCohereChatOptions.builder() .withTemperature(0.6) .withTopK(10) @@ -201,11 +201,11 @@ BedrockCohereChatModel chatModel = new BedrockCohereChatModel(api, .withMaxTokens(678) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -242,7 +242,7 @@ var request = CohereChatRequest .withTruncate(Truncate.NONE) .build(); -CohereChatResponse response = cohereChatApi.chatCompletion(request); +CohereChatResponse response = this.cohereChatApi.chatCompletion(this.request); var request = CohereChatRequest .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") @@ -258,8 +258,8 @@ var request = CohereChatRequest .withTruncate(Truncate.NONE) .build(); -Flux responseStream = cohereChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.cohereChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc index ca29f165e16..5ddad57756f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc @@ -140,7 +140,7 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } } @@ -181,13 +181,13 @@ Ai21Jurassic2ChatBedrockApi api = new Ai21Jurassic2ChatBedrockApi(Ai21Jurassic2C new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAi21Jurassic2ChatModel chatModel = new BedrockAi21Jurassic2ChatModel(api, +BedrockAi21Jurassic2ChatModel chatModel = new BedrockAi21Jurassic2ChatModel(this.api, BedrockAi21Jurassic2ChatOptions.builder() .withTemperature(0.5) .withMaxTokens(100) .withTopP(0.9).build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -214,7 +214,7 @@ Ai21Jurassic2ChatRequest request = Ai21Jurassic2ChatRequest.builder("Hello, my n .withMaxTokens(20) .build(); -Ai21Jurassic2ChatResponse response = jurassic2ChatApi.chatCompletion(request); +Ai21Jurassic2ChatResponse response = this.jurassic2ChatApi.chatCompletion(this.request); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc index a51ca340861..445e4423a14 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc @@ -145,13 +145,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -191,17 +191,17 @@ LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA2_70B_CHAT new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockLlamaChatModel chatModel = new BedrockLlamaChatModel(api, +BedrockLlamaChatModel chatModel = new BedrockLlamaChatModel(this.api, BedrockLlamaChatOptions.builder() .withTemperature(0.5) .withMaxGenLen(100) .withTopP(0.9).build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -230,11 +230,11 @@ LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is") .withMaxGenLen(20) .build(); -LlamaChatResponse response = llamaChatApi.chatCompletion(request); +LlamaChatResponse response = this.llamaChatApi.chatCompletion(this.request); // Streaming response -Flux responseStream = llamaChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.llamaChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc index 9f34fe50bd9..970963d45e6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc @@ -143,13 +143,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -190,18 +190,18 @@ TitanChatBedrockApi titanApi = new TitanChatBedrockApi( new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockTitanChatModel chatModel = new BedrockTitanChatModel(titanApi, +BedrockTitanChatModel chatModel = new BedrockTitanChatModel(this.titanApi, BedrockTitanChatOptions.builder() .withTemperature(0.6) .withTopP(0.8) .withMaxTokenCount(100) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -229,11 +229,11 @@ TitanChatRequest titanChatRequest = TitanChatRequest.builder("Give me the names .withStopSequences(List.of("|")) .build(); -TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); +TitanChatResponse response = this.titanBedrockApi.chatCompletion(this.titanChatRequest); -Flux response = titanBedrockApi.chatCompletionStream(titanChatRequest); +Flux response = this.titanBedrockApi.chatCompletionStream(this.titanChatRequest); -List results = response.collectList().block(); +List results = this.response.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java[TitanChatBedrockApi]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc index eecb5cb537a..793e8940a13 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc @@ -153,7 +153,7 @@ AnthropicChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), AnthropicChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -180,7 +180,7 @@ var promptOptions = AnthropicChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc index a078e9f9c7b..93810865ceb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc @@ -148,7 +148,7 @@ AzureOpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), AzureOpenAiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -185,7 +185,7 @@ var promptOptions = AzureOpenAiChatOptions.builder() .build())) .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc index 039c3f8d98c..3fbeef62250 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc @@ -153,7 +153,7 @@ MiniMaxChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), MiniMaxChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -190,7 +190,7 @@ var promptOptions = MiniMaxChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc index a0bf2ed38c3..0cf4e058950 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc @@ -151,7 +151,7 @@ MistralAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, MistralAiChatOptions.builder().withFunction("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); @@ -178,7 +178,7 @@ var promptOptions = MistralAiChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(userMessage, promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc index 23a2f9017eb..fe04e62ea0c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc @@ -153,7 +153,7 @@ MoonshotChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), MoonshotChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -190,7 +190,7 @@ var promptOptions = MoonshotChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc index 652b6928593..0e7ba4e900f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc @@ -155,7 +155,7 @@ OllamaChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, OllamaOptions.builder().withFunction("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); @@ -191,7 +191,7 @@ var promptOptions = OllamaOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(userMessage, promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc index 851f019bc06..1468380af3a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc @@ -148,7 +148,7 @@ OpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, OpenAiChatOptions.builder().withFunction("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); @@ -184,7 +184,7 @@ var promptOptions = OpenAiChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(userMessage, promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. @@ -254,7 +254,7 @@ BiFunction OpenAiChatOptions options = OpenAiChatOptions.builder() .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(weatherFunction) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(this.weatherFunction) .withName("getCurrentWeather") .withDescription("Get the weather in location") .build())) @@ -269,7 +269,7 @@ You can then use these options when making a call to the chat model: [source,java] ---- UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), options)); +ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage), options)); ---- This approach allows you to pass session-specific or user-specific information to your functions, enabling more contextual and personalized responses. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc index bf06d12578a..b4aec52ca95 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc @@ -156,7 +156,7 @@ VertexAiGeminiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), VertexAiGeminiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -194,7 +194,7 @@ var promptOptions = VertexAiGeminiChatOptions.builder() .build())) .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc index c62d80616ac..30f425786d6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc @@ -153,7 +153,7 @@ ZhiPuAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), ZhiPuAiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -190,7 +190,7 @@ var promptOptions = ZhiPuAiChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc index cd622ef9d2c..0550b4fec80 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc @@ -261,13 +261,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -307,14 +307,14 @@ var openAiChatOptions = OpenAiChatOptions.builder() .withTemperature(0.4) .withMaxTokens(200) .build(); -var chatModel = new OpenAiChatModel(openAiApi, openAiChatOptions); +var chatModel = new OpenAiChatModel(this.openAiApi, this.openAiChatOptions); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc index 459e5eb5ce8..3cd180140f0 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc @@ -95,7 +95,7 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } } ---- @@ -131,7 +131,7 @@ Next, create a `HuggingfaceChatModel` and use it for text generations: ---- HuggingfaceChatModel chatModel = new HuggingfaceChatModel(apiKey, url); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); System.out.println(response.getGeneration().getResult().getOutput().getContent()); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc index 160b63ede94..2cacdd0f96b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc @@ -161,13 +161,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -203,17 +203,17 @@ Next, create a `MiniMaxChatModel` and use it for text generations: ---- var miniMaxApi = new MiniMaxApi(System.getenv("MINIMAX_API_KEY")); -var chatModel = new MiniMaxChatModel(miniMaxApi, MiniMaxChatOptions.builder() +var chatModel = new MiniMaxChatModel(this.miniMaxApi, MiniMaxChatOptions.builder() .withModel(MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatModel.stream( +Flux streamResponse = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -235,12 +235,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = miniMaxApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, false)); +ResponseEntity response = this.miniMaxApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, false)); // Streaming request -Flux streamResponse = miniMaxApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, true)); +Flux streamResponse = this.miniMaxApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java[MiniMaxApi.java]'s JavaDoc for further information. @@ -259,21 +259,21 @@ Here is a simple snippet how to use the web search: UserMessage userMessage = new UserMessage( "How many gold medals has the United States won in total at the 2024 Olympics?"); -List messages = new ArrayList<>(List.of(userMessage)); +List messages = new ArrayList<>(List.of(this.userMessage)); List functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()); MiniMaxChatOptions options = MiniMaxChatOptions.builder() .withModel(MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.value) - .withTools(functionTool) + .withTools(this.functionTool) .build(); // Sync request -ChatResponse response = chatModel.call(new Prompt(messages, options)); +ChatResponse response = chatModel.call(new Prompt(this.messages, this.options)); // Streaming request -Flux streamResponse = chatModel.stream(new Prompt(messages, options)); +Flux streamResponse = chatModel.stream(new Prompt(this.messages, this.options)); ---- ==== MiniMaxApi Samples diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc index e0dd67906b7..6496f094f13 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc @@ -179,13 +179,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -221,17 +221,17 @@ Next, create a `MistralAiChatModel` and use it for text generations: ---- var mistralAiApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY")); -var chatModel = new MistralAiChatModel(mistralAiApi, MistralAiChatOptions.builder() +var chatModel = new MistralAiChatModel(this.mistralAiApi, MistralAiChatOptions.builder() .withModel(MistralAiApi.ChatModel.LARGE.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -252,12 +252,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = mistralAiApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, false)); +ResponseEntity response = this.mistralAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, false)); // Streaming request -Flux streamResponse = mistralAiApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, true)); +Flux streamResponse = this.mistralAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java[MistralAiApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc index 36401bd13f6..a08478f38c5 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc @@ -160,13 +160,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -202,17 +202,17 @@ Next, create a `MoonshotChatModel` and use it for text generations: ---- var moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY")); -var chatModel = new MoonshotChatModel(moonshotApi, MoonshotChatOptions.builder() +var chatModel = new MoonshotChatModel(this.moonshotApi, MoonshotChatOptions.builder() .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatModel.stream( +Flux streamResponse = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -234,12 +234,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = moonshotApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, false)); +ResponseEntity response = this.moonshotApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, false)); // Streaming request -Flux streamResponse = moonshotApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, true)); +Flux streamResponse = this.moonshotApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java[MoonshotApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc index 406422c250c..17bc1204ae4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc @@ -240,13 +240,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index 71e30665b5e..8e915a0aec3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -251,9 +251,9 @@ Below is a straightforward code example excerpted from link:https://github.com/s var imageResource = new ClassPathResource("/multimodal.test.png"); var userMessage = new UserMessage("Explain what do you see on this picture?", - new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); + new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = chatModel.call(new Prompt(this.userMessage, OllamaOptions.builder().withModel(OllamaModel.LLAVA)).build()); ---- @@ -317,13 +317,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } @@ -369,16 +369,16 @@ Next, create an `OllamaChatModel` instance and use it to send requests for text ---- var ollamaApi = new OllamaApi(); -var chatModel = new OllamaChatModel(ollamaApi, +var chatModel = new OllamaChatModel(this.ollamaApi, OllamaOptions.create() .withModel(OllamaOptions.DEFAULT_MODEL) .withTemperature(0.9)); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -414,7 +414,7 @@ var request = ChatRequest.builder("orca-mini") .withOptions(OllamaOptions.create().withTemperature(0.9)) .build(); -ChatResponse response = ollamaApi.chat(request); +ChatResponse response = this.ollamaApi.chat(this.request); // Streaming request var request2 = ChatRequest.builder("orca-mini") @@ -425,5 +425,5 @@ var request2 = ChatRequest.builder("orca-mini") .withOptions(OllamaOptions.create().withTemperature(0.9).toMap()) .build(); -Flux streamingResponse = ollamaApi.streamingChat(request2); +Flux streamingResponse = this.ollamaApi.streamingChat(this.request2); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index 88e6a4b6d5c..8ee2e8d4f87 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -178,9 +178,9 @@ Below is a code example excerpted from link:https://github.com/spring-projects/s var imageResource = new ClassPathResource("/multimodal.test.png"); var userMessage = new UserMessage("Explain what do you see on this picture?", - new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); + new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = chatModel.call(new Prompt(this.userMessage, OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())); ---- @@ -194,7 +194,7 @@ var userMessage = new UserMessage("Explain what do you see on this picture?", new Media(MimeTypeUtils.IMAGE_PNG, "https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = chatModel.call(new Prompt(this.userMessage, OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())); ---- @@ -258,10 +258,10 @@ String jsonSchema = """ Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", OpenAiChatOptions.builder() .withModel(ChatModel.GPT_4_O_MINI) - .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema)) + .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, this.jsonSchema)) .build()); -ChatResponse response = this.openAiChatModel.call(prompt); +ChatResponse response = this.openAiChatModel.call(this.prompt); ---- NOTE: Adhere to the OpenAI link:https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[subset of the JSON Schema language] format. @@ -288,18 +288,18 @@ record MathReasoning( var outputConverter = new BeanOutputConverter<>(MathReasoning.class); -var jsonSchema = outputConverter.getJsonSchema(); +var jsonSchema = this.outputConverter.getJsonSchema(); Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", OpenAiChatOptions.builder() .withModel(ChatModel.GPT_4_O_MINI) - .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema)) + .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, this.jsonSchema)) .build()); -ChatResponse response = this.openAiChatModel.call(prompt); -String content = response.getResult().getOutput().getContent(); +ChatResponse response = this.openAiChatModel.call(this.prompt); +String content = this.response.getResult().getOutput().getContent(); -MathReasoning mathReasoning = outputConverter.convert(content); +MathReasoning mathReasoning = this.outputConverter.convert(this.content); ---- NOTE: Ensure you use the `@JsonProperty(required = true,...)` annotation. @@ -353,13 +353,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -399,13 +399,13 @@ var openAiChatOptions = OpenAiChatOptions.builder() .withTemperature(0.4) .withMaxTokens(200) .build(); -var chatModel = new OpenAiChatModel(openAiApi, openAiChatOptions); +var chatModel = new OpenAiChatModel(this.openAiApi, this.openAiChatOptions); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -431,12 +431,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = openAiApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false)); +ResponseEntity response = this.openAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), "gpt-3.5-turbo", 0.8, false)); // Streaming request -Flux streamResponse = openAiApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true)); +Flux streamResponse = this.openAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), "gpt-3.5-turbo", 0.8, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java[OpenAiApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc index 43bc86c48bc..19994a4f437 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc @@ -164,13 +164,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatClient.call(message)); + return Map.of("generation", this.chatClient.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatClient.stream(prompt); + return this.chatClient.stream(prompt); } } ---- @@ -206,17 +206,17 @@ Next, create a `QianFanChatModel` and use it for text generations: ---- var qianFanApi = new QianFanApi(System.getenv("QIANFAN_API_KEY"), System.getenv("QIANFAN_SECRET_KEY")); -var chatClient = new QianFanChatModel(qianFanApi, QianFanChatOptions.builder() +var chatClient = new QianFanChatModel(this.qianFanApi, QianFanChatOptions.builder() .withModel(QianFanApi.ChatModel.ERNIE_Speed_8K.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatClient.call( +ChatResponse response = this.chatClient.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatClient.stream( +Flux streamResponse = this.chatClient.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -240,12 +240,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = qianFanApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, false)); +ResponseEntity response = this.qianFanApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), this.systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, false)); // Streaming request -Flux streamResponse = qianFanApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, true)); +Flux streamResponse = this.qianFanApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), this.systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java[QianFanApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 9e6c238b34c..58ac4280124 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -139,9 +139,9 @@ Below is a simple code example extracted from https://github.com/spring-projects byte[] data = new ClassPathResource("/vertex-test.png").getContentAsByteArray(); var userMessage = new UserMessage("Explain what do you see on this picture?", - List.of(new Media(MimeTypeUtils.IMAGE_PNG, data))); + List.of(new Media(MimeTypeUtils.IMAGE_PNG, this.data))); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); +ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage))); ---- == Sample Controller @@ -177,13 +177,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -219,13 +219,13 @@ Next, create a `VertexAiGeminiChatModel` and use it for text generations: ---- VertexAI vertexApi = new VertexAI(projectId, location); -var chatModel = new VertexAiGeminiChatModel(vertexApi, +var chatModel = new VertexAiGeminiChatModel(this.vertexApi, VertexAiGeminiChatOptions.builder() .withModel(ChatModel.GEMINI_PRO_1_5_PRO) .withTemperature(0.4) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc index 7d9bcbb0a56..017f4be42a9 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc @@ -136,13 +136,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -178,12 +178,12 @@ Next, create a `VertexAiPaLm2ChatModel` and use it for text generations: ---- VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); -var chatModel = new VertexAiPaLm2ChatModel(vertexAiApi, +var chatModel = new VertexAiPaLm2ChatModel(this.vertexAiApi, VertexAiPaLm2ChatOptions.builder() .withTemperature(0.4) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -207,15 +207,15 @@ VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); // Generate var prompt = new MessagePrompt(List.of(new Message("0", "Hello, how are you?"))); -GenerateMessageRequest request = new GenerateMessageRequest(prompt); +GenerateMessageRequest request = new GenerateMessageRequest(this.prompt); -GenerateMessageResponse response = vertexAiApi.generateMessage(request); +GenerateMessageResponse response = this.vertexAiApi.generateMessage(this.request); // Embed text -Embedding embedding = vertexAiApi.embedText("Hello, how are you?"); +Embedding embedding = this.vertexAiApi.embedText("Hello, how are you?"); // Batch embedding -List embeddings = vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); +List embeddings = this.vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc index 9c63162d952..5f0727dcf22 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc @@ -119,7 +119,7 @@ public class MyClass { Prompt prompt = new Prompt(new SystemMessage(userInput), options); - var results = chatModel.call(prompt); + var results = this.chatModel.call(prompt); var generatedText = results.getResult().getOutput().getContent(); @@ -135,7 +135,7 @@ public class MyClass { Prompt prompt = new Prompt(new SystemMessage(userInput), options); - var results = chatModel.stream(prompt).collectList().block(); // wait till the stream is resolved (completed) + var results = this.chatModel.stream(prompt).collectList().block(); // wait till the stream is resolved (completed) var generatedText = results.stream() .map(generation -> generation.getResult().getOutput().getContent()) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc index 7f417336dfe..9a178bbe6d7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc @@ -162,13 +162,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -204,17 +204,17 @@ Next, create a `ZhiPuAiChatModel` and use it for text generations: ---- var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); -var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder() +var chatModel = new ZhiPuAiChatModel(this.zhiPuAiApi, ZhiPuAiChatOptions.builder() .withModel(ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatModel.stream( +Flux streamResponse = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -236,12 +236,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = zhiPuAiApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, false)); +ResponseEntity response = this.zhiPuAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, false)); // Streaming request -Flux streamResponse = zhiPuAiApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, true)); +Flux streamResponse = this.zhiPuAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java[ZhiPuAiApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index b623e01bab2..2ba0f6f62e8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -57,11 +57,11 @@ Then, create a `ChatClient.Builder` instance programmatically for every `ChatMod ---- ChatModel myChatModel = ... // usually autowired -ChatClient.Builder builder = ChatClient.builder(myChatModel); +ChatClient.Builder builder = ChatClient.builder(this.myChatModel); // or create a ChatClient with the default builder settings: -ChatClient chatClient = ChatClient.create(myChatModel); +ChatClient chatClient = ChatClient.create(this.myChatModel); ---- == ChatClient Fluent API @@ -156,13 +156,13 @@ Flux flux = this.chatClient.prompt() Generate the filmography for a random actor. {format} """) - .param("format", converter.getFormat())) + .param("format", this.converter.getFormat())) .stream() .content(); -String content = flux.collectList().block().stream().collect(Collectors.joining()); +String content = this.flux.collectList().block().stream().collect(Collectors.joining()); -List actorFilms = converter.convert(content); +List actorFilms = this.converter.convert(this.content); ---- == call() return values @@ -224,7 +224,7 @@ class AIController { @GetMapping("/ai/simple") public Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("completion", chatClient.prompt().user(message).call().content()); + return Map.of("completion", this.chatClient.prompt().user(message).call().content()); } } ---- @@ -268,7 +268,7 @@ class AIController { @GetMapping("/ai") Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message, String voice) { return Map.of("completion", - chatClient.prompt() + this.chatClient.prompt() .system(sp -> sp.param("voice", voice)) .user(message) .call() @@ -397,7 +397,7 @@ ChatClient chatClient = ChatClient.builder(chatModel) .build(); // Update filter expression at runtime -String content = chatClient.prompt() +String content = this.chatClient.prompt() .user("Please answer my question XYZ") .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'")) .call() diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc index f4620cd5c7a..f70999a1f65 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc @@ -197,13 +197,13 @@ var openAIClient = OpenAIClientBuilder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .buildClient(); -var embeddingModel = new AzureOpenAiEmbeddingModel(openAIClient) +var embeddingModel = new AzureOpenAiEmbeddingModel(this.openAIClient) .withDefaultOptions(AzureOpenAiEmbeddingOptions.builder() .withModel("text-embedding-ada-002") .withUser("user-6") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc index de04197928f..0786bace4ad 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc @@ -172,9 +172,9 @@ var cohereEmbeddingApi =new CohereEmbeddingBedrockApi( EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper()); -var embeddingModel = new BedrockCohereEmbeddingModel(cohereEmbeddingApi); +var embeddingModel = new BedrockCohereEmbeddingModel(this.cohereEmbeddingApi); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- @@ -202,7 +202,7 @@ CohereEmbeddingRequest request = new CohereEmbeddingRequest( CohereEmbeddingRequest.InputType.search_document, CohereEmbeddingRequest.Truncate.NONE); -CohereEmbeddingResponse response = api.embedding(request); +CohereEmbeddingResponse response = this.api.embedding(this.request); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc index e694649e1f9..2e0e93adfe9 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc @@ -170,9 +170,9 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp var titanEmbeddingApi = new TitanEmbeddingBedrockApi( TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id()); -var embeddingModel = new BedrockTitanEmbeddingModel(titanEmbeddingApi); +var embeddingModel = new BedrockTitanEmbeddingModel(this.titanEmbeddingApi); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World")); // NOTE titan does not support batch embedding. ---- @@ -197,7 +197,7 @@ TitanEmbeddingRequest request = TitanEmbeddingRequest.builder() .withInputText("I like to eat apples.") .build(); -TitanEmbeddingResponse response = titanEmbedApi.embedding(request); +TitanEmbeddingResponse response = this.titanEmbedApi.embedding(this.request); ---- To embed an image you need to convert it into `base64` format: @@ -213,8 +213,8 @@ byte[] image = new DefaultResourceLoader() TitanEmbeddingRequest request = TitanEmbeddingRequest.builder() - .withInputImage(Base64.getEncoder().encodeToString(image)) + .withInputImage(Base64.getEncoder().encodeToString(this.image)) .build(); -TitanEmbeddingResponse response = titanEmbedApi.embedding(request); +TitanEmbeddingResponse response = this.titanEmbedApi.embedding(this.request); ---- \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc index b0a917098ba..114cc624436 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc @@ -183,12 +183,12 @@ Next, create an `MiniMaxEmbeddingModel` instance and use it to compute the simil ---- var miniMaxApi = new MiniMaxApi(System.getenv("MINIMAX_API_KEY")); -var embeddingModel = new MiniMaxEmbeddingModel(miniMaxApi) +var embeddingModel = new MiniMaxEmbeddingModel(this.miniMaxApi) .withDefaultOptions(MiniMaxChatOptions.build() .withModel("embo-01") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc index 23f2781ba4b..46e9d4953f0 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc @@ -184,13 +184,13 @@ Next, create an `MistralAiEmbeddingModel` instance and use it to compute the sim ---- var mistralAiApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY")); -var embeddingModel = new MistralAiEmbeddingModel(mistralAiApi, +var embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MistralAiEmbeddingOptions.builder() .withModel("mistral-embed") .withEncodingFormat("float") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc index b8d42eee241..d9409675a7a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc @@ -159,15 +159,15 @@ final String REGION = "us-chicago-1"; final String COMPARTMENT_ID = System.getenv("OCI_COMPARTMENT_ID"); var authProvider = new ConfigFileAuthenticationDetailsProvider( - CONFIG_FILE, PROFILE); + this.CONFIG_FILE, this.PROFILE); var aiClient = GenerativeAiInferenceClient.builder() - .region(Region.valueOf(REGION)) - .build(authProvider); + .region(Region.valueOf(this.REGION)) + .build(this.authProvider); var options = OCIEmbeddingOptions.builder() - .withModel(EMBEDDING_MODEL) - .withCompartment(COMPARTMENT_ID) + .withModel(this.EMBEDDING_MODEL) + .withCompartment(this.COMPARTMENT_ID) .withServingMode("on-demand") .build(); -var embeddingModel = new OCIEmbeddingModel(aiClient, options); -List embedding = embeddingModel.embed(new Document("How many provinces are in Canada?")); +var embeddingModel = new OCIEmbeddingModel(this.aiClient, this.options); +List embedding = this.embeddingModel.embed(new Document("How many provinces are in Canada?")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc index 726367892fc..b522a169429 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc @@ -285,12 +285,12 @@ Next, create an `OllamaEmbeddingModel` instance and use it to compute the embedd ---- var ollamaApi = new OllamaApi(); -var embeddingModel = new OllamaEmbeddingModel(ollamaApi, +var embeddingModel = new OllamaEmbeddingModel(this.ollamaApi, OllamaOptions.builder() .withModel(OllamaModel.MISTRAL.id()) .build()); -EmbeddingResponse embeddingResponse = embeddingModel.call( +EmbeddingResponse embeddingResponse = this.embeddingModel.call( new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), OllamaOptions.builder() .withModel("chroma/all-minilm-l6-v2-f32")) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc index 0f9fe553aa9..51a53bb432a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc @@ -176,7 +176,7 @@ embeddingModel.setTokenizerOptions(Map.of("padding", "true")); embeddingModel.afterPropertiesSet(); -List> embeddings = embeddingModel.embed(List.of("Hello world", "World is big")); +List> embeddings = this.embeddingModel.embed(List.of("Hello world", "World is big")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc index 33a93c9a9c4..c1669acee63 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc @@ -196,7 +196,7 @@ Next, create an `OpenAiEmbeddingModel` instance and use it to compute the simila var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); var embeddingModel = new OpenAiEmbeddingModel( - openAiApi, + this.openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder() .withModel("text-embedding-ada-002") @@ -204,7 +204,7 @@ var embeddingModel = new OpenAiEmbeddingModel( .build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc index f3c04e0f240..7fc86f0abe5 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc @@ -155,7 +155,7 @@ PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbc embeddingModel.afterPropertiesSet(); // initialize the jdbc template and database. -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc index 260ad63631a..31c7c982653 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc @@ -192,7 +192,7 @@ var embeddingClient = new QianFanEmbeddingModel(qianFanApi) .withModel("bge_large_en") .build()); -EmbeddingResponse embeddingResponse = embeddingClient +EmbeddingResponse embeddingResponse = this.embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc index 526d07c390a..1f9db4495b7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc @@ -127,20 +127,20 @@ VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions. .withModel(VertexAiMultimodalEmbeddingOptions.DEFAULT_MODEL_NAME) .build(); -var embeddingModel = new VertexAiMultimodalEmbeddingModel(connectionDetails, options); +var embeddingModel = new VertexAiMultimodalEmbeddingModel(this.connectionDetails, this.options); Media imageMedial = new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")); Media videoMedial = new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")); -var document = new Document("Explain what do you see on this video?", List.of(imageMedial, videoMedial), Map.of()); +var document = new Document("Explain what do you see on this video?", List.of(this.imageMedial, this.videoMedial), Map.of()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); -DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(document), +DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(this.document), EmbeddingOptions.EMPTY); -EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); +EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(this.embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(3); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc index b81e56d92fc..63365196f2e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc @@ -146,9 +146,9 @@ Next, create a `VertexAiPaLm2EmbeddingModel` and use it for text generations: ---- VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); -var embeddingModel = new VertexAiPaLm2EmbeddingModel(vertexAiApi); +var embeddingModel = new VertexAiPaLm2EmbeddingModel(this.vertexAiApi); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- @@ -169,15 +169,15 @@ VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); // Generate var prompt = new MessagePrompt(List.of(new Message("0", "Hello, how are you?"))); -GenerateMessageRequest request = new GenerateMessageRequest(prompt); +GenerateMessageRequest request = new GenerateMessageRequest(this.prompt); -GenerateMessageResponse response = vertexAiApi.generateMessage(request); +GenerateMessageResponse response = this.vertexAiApi.generateMessage(this.request); // Embed text -Embedding embedding = vertexAiApi.embedText("Hello, how are you?"); +Embedding embedding = this.vertexAiApi.embedText("Hello, how are you?"); // Batch embedding -List embeddings = vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); +List embeddings = this.vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc index 56bdef6d5a2..839319c52e5 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc @@ -154,9 +154,9 @@ VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) .build(); -var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, options); +var embeddingModel = new VertexAiTextEmbeddingModel(this.connectionDetails, this.options); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc index 444dede3984..0fd6a08d1ac 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc @@ -183,12 +183,12 @@ Next, create an `ZhiPuAiEmbeddingModel` instance and use it to compute the simil ---- var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); -var embeddingModel = new ZhiPuAiEmbeddingModel(zhiPuAiApi) +var embeddingModel = new ZhiPuAiEmbeddingModel(this.zhiPuAiApi) .withDefaultOptions(ZhiPuAiChatOptions.build() .withModel("embedding-2") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc index f8daf7b093c..06c5808d7e6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc @@ -123,7 +123,7 @@ class MyJsonReader { } List loadJsonAsDocuments() { - JsonReader jsonReader = new JsonReader(resource, "description", "content"); + JsonReader jsonReader = new JsonReader(this.resource, "description", "content"); return jsonReader.get(); } } @@ -189,7 +189,7 @@ This method allows you to use a JSON Pointer to retrieve a specific part of the [source,java] ---- JsonReader jsonReader = new JsonReader(resource, "description"); -List documents = jsonReader.get("/store/books/0"); +List documents = this.jsonReader.get("/store/books/0"); ---- ==== Example JSON Structure @@ -236,7 +236,7 @@ class MyTextReader { this.resource = resource; } List loadText() { - TextReader textReader = new TextReader(resource); + TextReader textReader = new TextReader(this.resource); textReader.getCustomMetadata().put("filename", "text-source.txt"); return textReader.read(); @@ -281,7 +281,7 @@ The `TextReader` processes text content as follows: [source,java] ---- List documents = textReader.get(); -List splitDocuments = new TokenTextSplitter().apply(documents); +List splitDocuments = new TokenTextSplitter().apply(this.documents); ---- * The reader uses Spring's `Resource` abstraction, allowing it to read from various sources (classpath, file system, URL, etc.). @@ -313,7 +313,7 @@ class MyMarkdownReader { .withAdditionalMetadata("filename", "code.md") .build(); - MarkdownDocumentReader reader = new MarkdownDocumentReader(resource, config); + MarkdownDocumentReader reader = new MarkdownDocumentReader(this.resource, config); return reader.get(); } } @@ -501,7 +501,7 @@ class MyTikaDocumentReader { } List loadText() { - TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resource); + TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(this.resource); return tikaDocumentReader.read(); } } @@ -576,7 +576,7 @@ Document doc2 = new Document("Another document with content that will be split b Map.of("source", "example2.txt")); TokenTextSplitter splitter = new TokenTextSplitter(); -List splitDocuments = splitter.apply(List.of(doc1, doc2)); +List splitDocuments = this.splitter.apply(List.of(this.doc1, this.doc2)); for (Document doc : splitDocuments) { System.out.println("Chunk: " + doc.getContent()); @@ -612,7 +612,7 @@ class MyKeywordEnricher { } List enrichDocuments(List documents) { - KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5); + KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(this.chatModel, 5); return enricher.apply(documents); } } @@ -655,10 +655,10 @@ KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5); Document doc = new Document("This is a document about artificial intelligence and its applications in modern technology."); -List enrichedDocs = enricher.apply(List.of(doc)); +List enrichedDocs = enricher.apply(List.of(this.doc)); -Document enrichedDoc = enrichedDocs.get(0); -String keywords = (String) enrichedDoc.getMetadata().get("excerpt_keywords"); +Document enrichedDoc = this.enrichedDocs.get(0); +String keywords = (String) this.enrichedDoc.getMetadata().get("excerpt_keywords"); System.out.println("Extracted keywords: " + keywords); ---- @@ -697,7 +697,7 @@ class MySummaryEnricher { } List enrichDocuments(List documents) { - return enricher.apply(documents); + return this.enricher.apply(documents); } } ---- @@ -757,7 +757,7 @@ SummaryMetadataEnricher enricher = new SummaryMetadataEnricher(chatModel, Document doc1 = new Document("Content of document 1"); Document doc2 = new Document("Content of document 2"); -List enrichedDocs = enricher.apply(List.of(doc1, doc2)); +List enrichedDocs = enricher.apply(List.of(this.doc1, this.doc2)); // Check the metadata of the enriched documents for (Document doc : enrichedDocs) { diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index aad70761170..db7d51693c4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -154,7 +154,7 @@ To let the model know and call your `CurrentWeather` function you need to enable ---- ChatClient chatClient = ... -ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") +ChatResponse response = this.chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") .functions("CurrentWeather") // Enable the function .call(). chatResponse(); @@ -181,7 +181,7 @@ In addition to the auto-configuration, you can register callback functions, dyna ---- ChatClient chatClient = ... -ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") +ChatResponse response = this.chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") .functions(new FunctionCallbackWrapper<>( "CurrentWeather", // name "Get the weather in location", // function description @@ -230,7 +230,7 @@ BiFunction ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallbackWrapper.builder(weatherFunction) + .functions(FunctionCallbackWrapper.builder(this.weatherFunction) .withName("getCurrentWeather") .withDescription("Get the weather in location") .build()) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc index 82fefe85428..81c1c9120fb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc @@ -81,8 +81,8 @@ OpenAiModerationOptions moderationOptions = OpenAiModerationOptions.builder() .withModel("text-moderation-latest") .build(); -ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", moderationOptions); -ModerationResponse response = openAiModerationModel.call(moderationPrompt); +ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", this.moderationOptions); +ModerationResponse response = openAiModerationModel.call(this.moderationPrompt); // Access the moderation results Moderation moderation = moderationResponse.getResult().getOutput(); @@ -97,7 +97,7 @@ for (ModerationResult result : moderation.getResults()) { System.out.println("Flagged: " + result.isFlagged()); // Access categories - Categories categories = result.getCategories(); + Categories categories = this.result.getCategories(); System.out.println("\nCategories:"); System.out.println("Sexual: " + categories.isSexual()); System.out.println("Hate: " + categories.isHate()); @@ -112,7 +112,7 @@ for (ModerationResult result : moderation.getResults()) { System.out.println("Violence: " + categories.isViolence()); // Access category scores - CategoryScores scores = result.getCategoryScores(); + CategoryScores scores = this.result.getCategoryScores(); System.out.println("\nCategory Scores:"); System.out.println("Sexual: " + scores.getSexual()); System.out.println("Hate: " + scores.getHate()); @@ -158,14 +158,14 @@ Next, create an OpenAiModerationModel: ---- OpenAiModerationApi openAiModerationApi = new OpenAiModerationApi(System.getenv("OPENAI_API_KEY")); -OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(openAiModerationApi); +OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(this.openAiModerationApi); OpenAiModerationOptions moderationOptions = OpenAiModerationOptions.builder() .withModel("text-moderation-latest") .build(); -ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", moderationOptions); -ModerationResponse response = openAiModerationModel.call(moderationPrompt); +ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", this.moderationOptions); +ModerationResponse response = this.openAiModerationModel.call(this.moderationPrompt); ---- == Example Code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc index bbcab25b3b9..8afc5586d49 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc @@ -43,9 +43,9 @@ var imageResource = new ClassPathResource("/multimodal.test.png"); var userMessage = new UserMessage( "Explain what do you see in this picture?", // content - new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); // media + new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); // media -ChatResponse response = chatModel.call(new Prompt(userMessage)); +ChatResponse response = chatModel.call(new Prompt(this.userMessage)); ---- or with the fluent xref::api/chatclient.adoc[ChatClient] API: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc index 7e11009f5ab..6a41bacc587 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc @@ -202,7 +202,7 @@ A simple example taken from the https://github.com/Azure-Samples/spring-ai-azure PromptTemplate promptTemplate = new PromptTemplate("Tell me a {adjective} joke about {topic}"); -Prompt prompt = promptTemplate.create(Map.of("adjective", adjective, "topic", topic)); +Prompt prompt = this.promptTemplate.create(Map.of("adjective", adjective, "topic", topic)); return chatModel.call(prompt).getResult(); ``` @@ -215,7 +215,7 @@ String userText = """ Write at least a sentence for each pirate. """; -Message userMessage = new UserMessage(userText); +Message userMessage = new UserMessage(this.userText); String systemText = """ You are a helpful AI assistant that helps people find information. @@ -223,12 +223,12 @@ String systemText = """ You should reply to the user's request with your name and also in the style of a {voice}. """; -SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText); -Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); +SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemText); +Message systemMessage = this.systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); -Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); +Prompt prompt = new Prompt(List.of(this.userMessage, this.systemMessage)); -List response = chatModel.call(prompt).getResults(); +List response = chatModel.call(this.prompt).getResults(); ``` diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc index 90dd0c94436..1fc320cd346 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc @@ -71,7 +71,7 @@ The format instructions are most often appended to the end of the user input usi """; // user input with a "format" placeholder. Prompt prompt = new Prompt( new PromptTemplate( - userInputTemplate, + this.userInputTemplate, Map.of(..., "format", outputConverter.getFormat()) // replace the "format" placeholder with the converter's format. ).createMessage()); ---- @@ -124,7 +124,7 @@ or using the low-level `ChatModel` API directly: BeanOutputConverter beanOutputConverter = new BeanOutputConverter<>(ActorsFilms.class); -String format = beanOutputConverter.getFormat(); +String format = this.beanOutputConverter.getFormat(); String actor = "Tom Hanks"; @@ -134,9 +134,9 @@ String template = """ """; Generation generation = chatModel.call( - new PromptTemplate(template, Map.of("actor", actor, "format", format)).create()).getResult(); + new PromptTemplate(this.template, Map.of("actor", this.actor, "format", this.format)).create()).getResult(); -ActorsFilms actorsFilms = beanOutputConverter.convert(generation.getOutput().getContent()); +ActorsFilms actorsFilms = this.beanOutputConverter.convert(this.generation.getOutput().getContent()); ---- ==== Generic Bean Types @@ -159,17 +159,17 @@ or using the low-level `ChatModel` API directly: BeanOutputConverter> outputConverter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { }); -String format = outputConverter.getFormat(); +String format = this.outputConverter.getFormat(); String template = """ Generate the filmography of 5 movies for Tom Hanks and Bill Murray. {format} """; -Prompt prompt = new PromptTemplate(template, Map.of("format", format)).create(); +Prompt prompt = new PromptTemplate(this.template, Map.of("format", this.format)).create(); -Generation generation = chatModel.call(prompt).getResult(); +Generation generation = chatModel.call(this.prompt).getResult(); -List actorsFilms = outputConverter.convert(generation.getOutput().getContent()); +List actorsFilms = this.outputConverter.convert(this.generation.getOutput().getContent()); ---- === Map Output Converter @@ -191,18 +191,18 @@ or using the low-level `ChatModel` API directly: ---- MapOutputConverter mapOutputConverter = new MapOutputConverter(); -String format = mapOutputConverter.getFormat(); +String format = this.mapOutputConverter.getFormat(); String template = """ Provide me a List of {subject} {format} """; -Prompt prompt = new PromptTemplate(template, - Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)).create(); +Prompt prompt = new PromptTemplate(this.template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", this.format)).create(); -Generation generation = chatModel.call(prompt).getResult(); +Generation generation = chatModel.call(this.prompt).getResult(); -Map result = mapOutputConverter.convert(generation.getOutput().getContent()); +Map result = this.mapOutputConverter.convert(this.generation.getOutput().getContent()); ---- === List Output Converter @@ -224,18 +224,18 @@ or using the low-level `ChatModel API` directly: ---- ListOutputConverter listOutputConverter = new ListOutputConverter(new DefaultConversionService()); -String format = listOutputConverter.getFormat(); +String format = this.listOutputConverter.getFormat(); String template = """ List five {subject} {format} """; -Prompt prompt = new PromptTemplate(template, - Map.of("subject", "ice cream flavors", "format", format)).create(); +Prompt prompt = new PromptTemplate(this.template, + Map.of("subject", "ice cream flavors", "format", this.format)).create(); -Generation generation = this.chatModel.call(prompt).getResult(); +Generation generation = this.chatModel.call(this.prompt).getResult(); -List list = listOutputConverter.convert(generation.getOutput().getContent()); +List list = this.listOutputConverter.convert(this.generation.getOutput().getContent()); ---- == Supported AI Models diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc index 387c1805563..9f320f0d12c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc @@ -167,7 +167,7 @@ Additionally, `TokenCountBatchingStrategy` provides flexibility by allowing you ---- TokenCountEstimator customEstimator = new YourCustomTokenCountEstimator(); TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy( - customEstimator, + this.customEstimator, 8000, // maxInputTokenCount 0.1, // reservePercentage Document.DEFAULT_CONTENT_FORMATTER, @@ -256,7 +256,7 @@ Later, when a user question is passed into the AI model, a similarity search is ```java String question = - List similarDocuments = store.similaritySearch(question); + List similarDocuments = store.similaritySearch(this.question); ``` Additional options can be passed into the `similaritySearch` method to define how many documents to retrieve and a threshold of the similarity search. @@ -282,7 +282,7 @@ A simple example is as follows: [source, java] ---- FilterExpressionBuilder b = new FilterExpressionBuilder(); -Expression expression = b.eq("country", "BG").build(); +Expression expression = this.b.eq("country", "BG").build(); ---- You can build up sophisticated expressions by using the following operators: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc index aecdc6cc06a..82677abefbb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc @@ -177,7 +177,7 @@ or programmatically using the expression DSL: [source,java] ---- Filter.Expression f = new FilterExpressionBuilder() - .and(f.in("country", "UK", "NL"), f.gte("year", 2020)).build(); + .and(f.in("country", "UK", "NL"), this.f.gte("year", 2020)).build(); vectorStore.similaritySearch( SearchRequest.query("The World").withTopK(TOP_K) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc index df080dbb18d..2e1cac5af1c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc @@ -68,13 +68,13 @@ public class DemoApplication implements CommandLineRunner { public void run(String... args) throws Exception { Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1")); Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); - vectorStore.add(List.of(document1, document2)); - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + this.vectorStore.add(List.of(document1, document2)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); log.info("Search results: {}", results); // Remove the documents from the vector store - vectorStore.delete(List.of(document1.getId(), document2.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId())); } @Bean @@ -133,15 +133,15 @@ metadata2.put("country", "NL"); metadata2.put("year", 2022); metadata2.put("city", "Amsterdam"); -Document document1 = new Document("1", "A document about the UK", metadata1); -Document document2 = new Document("2", "A document about the Netherlands", metadata2); +Document document1 = new Document("1", "A document about the UK", this.metadata1); +Document document2 = new Document("2", "A document about the Netherlands", this.metadata2); vectorStore.add(List.of(document1, document2)); FilterExpressionBuilder builder = new FilterExpressionBuilder(); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) - .withFilterExpression((builder.in("country", "UK", "NL")).build())); + .withFilterExpression((this.builder.in("country", "UK", "NL")).build())); ---- == Setting up Azure Cosmos DB Vector Store without Auto Configuration @@ -190,9 +190,9 @@ public class DemoApplication implements CommandLineRunner { public void run(String... args) throws Exception { Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1")); Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); - vectorStore.add(List.of(document1, document2)); + this.vectorStore.add(List.of(document1, document2)); - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); log.info("Search results: {}", results); } @@ -216,7 +216,7 @@ public class DemoApplication implements CommandLineRunner { .gatewayMode() .buildAsyncClient(); - return new CosmosDBVectorStore(observationRegistry, null, cosmosClient, config, embeddingModel); + return new CosmosDBVectorStore(observationRegistry, null, cosmosClient, config, this.embeddingModel); } @Bean diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc index 5c477a2aefb..0bd1e5558b7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc @@ -101,7 +101,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc index 5f2b0cbf7ba..fab97b3f0f6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc @@ -101,7 +101,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[elasticsearchvector-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc index 7cebf79e209..9718debdb5d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc @@ -183,7 +183,7 @@ public class CricketWorldCupRepository implements HanaVectorRepository, List> splitter = new TokenTextSplitter(); List documents = splitter.apply(reader.get()); log.info("{} documents created from pdf file: {}", documents.size(), pdf.getFilename()); - hanaCloudVectorStore.accept(documents); + this.hanaCloudVectorStore.accept(documents); return ResponseEntity.ok().body(String.format("%d documents created from pdf file: %s", documents.size(), pdf.getFilename())); } @@ -304,7 +304,7 @@ public class CricketWorldCupHanaController { var userMessage = new UserMessage(message); Prompt prompt = new Prompt(List.of(similarDocsMessage, userMessage)); - String generation = chatModel.call(prompt).getResult().getOutput().getContent(); + String generation = this.chatModel.call(prompt).getResult().getOutput().getContent(); log.info("Generation: {}", generation); return Map.of("generation", generation); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc index 939a00aaebe..7675196e49f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc @@ -85,7 +85,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Manual Configuration diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc index b125371a649..af68e343d85 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc @@ -252,7 +252,7 @@ List results = vectorStore.similaritySearch( .withQuery("learn how to grow things") .withTopK(2) .withSimilarityThreshold(0.5) - .withFilterExpression(b.eq("author", "A").build()) + .withFilterExpression(this.b.eq("author", "A").build()) ); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc index d7437d89c5e..84bf23db101 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc @@ -109,7 +109,7 @@ List documents = List.of( vectorStore.add(List.of(document)); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc index 70271ae6ef1..7817d120264 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc @@ -91,7 +91,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[oracle-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc index e5c58458ede..2380f0f1c24 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc @@ -125,7 +125,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[pgvector-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc index 31c493ced83..4ca3f2fa7c3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc @@ -96,7 +96,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc index 573cc1068ef..6df3e7ee1ed 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc @@ -98,7 +98,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[qdrant-vectorstore-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc index 04d16a93ed5..cd6337ba78d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc @@ -96,7 +96,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc index f9ac8641986..6c655601ca3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc @@ -89,7 +89,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc index 0bfd2e661e3..5c1ce1be0fc 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc @@ -102,7 +102,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[weaviate-vectorstore-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc index c6189f8984f..ea7f72f7546 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc @@ -58,7 +58,7 @@ public class OldSimpleAiController { @GetMapping("/ai/simple") Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatClient.call(message)); + return Map.of("generation", this.chatClient.call(message)); } } ``` @@ -77,7 +77,7 @@ public class SimpleAiController { @GetMapping("/ai/simple") Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } } ``` @@ -109,7 +109,7 @@ class OldSimpleAiController { Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { return Map.of( "generation", - chatClient.call(message) + this.chatClient.call(message) ); } } @@ -131,7 +131,7 @@ class SimpleAiController { Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { return Map.of( "generation", - chatClient.prompt().user(message).call().content() + this.chatClient.prompt().user(message).call().content() ); } } diff --git a/spring-ai-docs/src/main/javadoc/overview.html b/spring-ai-docs/src/main/javadoc/overview.html index 47ee30f0d45..7fb094d7ad9 100644 --- a/spring-ai-docs/src/main/javadoc/overview.html +++ b/spring-ai-docs/src/main/javadoc/overview.html @@ -1,3 +1,19 @@ + +

    diff --git a/spring-ai-retry/pom.xml b/spring-ai-retry/pom.xml index b5008cd6572..848ac898383 100644 --- a/spring-ai-retry/pom.xml +++ b/spring-ai-retry/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java index fcd824693c3..44c405ca6d8 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.retry; /** diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java index bbc584335c3..53d99b1db96 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.retry; import java.io.IOException; @@ -40,21 +41,6 @@ */ public abstract class RetryUtils { - private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class); - - public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(TransientAiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - @Override - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); - public static final ResponseErrorHandler DEFAULT_RESPONSE_ERROR_HANDLER = new ResponseErrorHandler() { @Override @@ -81,4 +67,20 @@ public void handleError(@NonNull ClientHttpResponse response) throws IOException } }; + private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class); + + public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(TransientAiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .withListener(new RetryListener() { + + @Override + public void onError(RetryContext context, + RetryCallback callback, Throwable throwable) { + logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); + } + }) + .build(); + } diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java index 94a7104840f..95b6e37f668 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.retry; /** diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 2a738a8d475..200acce9ee1 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java index ff311b7cab4..a49203d29b8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; @@ -39,8 +42,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Thomas Vitale diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java index b83ba45401f..2a77786896d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import org.springframework.ai.anthropic.AnthropicChatModel; @@ -52,12 +53,12 @@ public AnthropicChatOptions getOptions() { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java index 3ad6e4ed9e5..53a74e35178 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import org.springframework.ai.anthropic.api.AnthropicApi; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java index c223713e8c3..b3b10416fbc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions; @@ -36,7 +37,7 @@ public class AzureOpenAiAudioTranscriptionProperties { private AzureOpenAiAudioTranscriptionOptions options = AzureOpenAiAudioTranscriptionOptions.builder().build(); public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -44,7 +45,7 @@ public void setEnabled(boolean enabled) { } public AzureOpenAiAudioTranscriptionOptions getOptions() { - return options; + return this.options; } public void setOptions(AzureOpenAiAudioTranscriptionOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index ee686d44c27..c40b97461f6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.azure.core.credential.TokenCredential; +import com.azure.core.util.ClientOptions; +import com.azure.core.util.Header; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; @@ -39,15 +48,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.core.credential.KeyCredential; -import com.azure.core.credential.TokenCredential; -import com.azure.core.util.ClientOptions; -import com.azure.core.util.Header; - -import io.micrometer.observation.ObservationRegistry; - /** * @author Piotr Olaszewski * @author Soby Chacko diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java index 7ae5ebc8d6b..58521d28cfa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java index 16a128260ea..6ede4e8b323 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -54,14 +54,14 @@ public void setEndpoint(String endpoint) { this.endpoint = endpoint; } - public void setApiKey(String apiKey) { - this.apiKey = apiKey; - } - public String getApiKey() { return this.apiKey; } + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + public String getOpenAiApiKey() { return this.openAiApiKey; } @@ -71,7 +71,7 @@ public void setOpenAiApiKey(String openAiApiKey) { } public Map getCustomHeaders() { - return customHeaders; + return this.customHeaders; } public void setCustomHeaders(Map customHeaders) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java index d7eb357bacc..eb88e4f6f39 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingOptions; @@ -39,7 +40,7 @@ public class AzureOpenAiEmbeddingProperties { private MetadataMode metadataMode = MetadataMode.EMBED; public AzureOpenAiEmbeddingOptions getOptions() { - return options; + return this.options; } public void setOptions(AzureOpenAiEmbeddingOptions options) { @@ -48,7 +49,7 @@ public void setOptions(AzureOpenAiEmbeddingOptions options) { } public MetadataMode getMetadataMode() { - return metadataMode; + return this.metadataMode; } public void setMetadataMode(MetadataMode metadataMode) { @@ -57,7 +58,7 @@ public void setMetadataMode(MetadataMode metadataMode) { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java index 26e1ae2c89c..4aea1459f4c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiImageOptions; @@ -24,7 +40,7 @@ public class AzureOpenAiImageOptionsProperties { private AzureOpenAiImageOptions options = AzureOpenAiImageOptions.builder().build(); public AzureOpenAiImageOptions getOptions() { - return options; + return this.options; } public void setOptions(AzureOpenAiImageOptions options) { @@ -32,7 +48,7 @@ public void setOptions(AzureOpenAiImageOptions options) { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java index 46ee804056d..10db7578206 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java index da8d3f34f0f..6d5338366e2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.bedrock; -import org.springframework.boot.context.properties.ConfigurationProperties; +package org.springframework.ai.autoconfigure.bedrock; import java.time.Duration; +import org.springframework.boot.context.properties.ConfigurationProperties; + /** * Configuration properties for Bedrock AWS connection. * @@ -51,7 +52,7 @@ public class BedrockAwsConnectionProperties { private Duration timeout = Duration.ofMinutes(5L); public String getRegion() { - return region; + return this.region; } public void setRegion(String awsRegion) { @@ -59,7 +60,7 @@ public void setRegion(String awsRegion) { } public String getAccessKey() { - return accessKey; + return this.accessKey; } public void setAccessKey(String accessKey) { @@ -67,7 +68,7 @@ public void setAccessKey(String accessKey) { } public String getSecretKey() { - return secretKey; + return this.secretKey; } public void setSecretKey(String secretKey) { @@ -75,7 +76,7 @@ public void setSecretKey(String secretKey) { } public Duration getTimeout() { - return timeout; + return this.timeout; } public void setTimeout(Duration timeout) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java index 5d3f2d5fb07..77a8912839b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java index daee6365bbe..5f25196005c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic; import java.util.List; @@ -70,7 +71,7 @@ public void setModel(String model) { } public AnthropicChatOptions getOptions() { - return options; + return this.options; } public void setOptions(AnthropicChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java index ac4f788355e..f385c18acef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic3; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java index 96ddc3e06b2..ef981bba4aa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic3; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.Anthropic3ChatOptions; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -70,7 +71,7 @@ public void setModel(String model) { } public Anthropic3ChatOptions getOptions() { - return options; + return this.options; } public void setOptions(Anthropic3ChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java index 706474ca35c..95e1a18964f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java index 0381d591b9f..723c91ef228 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java index 27b1edd44a5..86ba3f76b3d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java index 8e6327941e1..3841fb6f38a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java index ac46a66f6ae..b84fbb11279 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,9 @@ package org.springframework.ai.autoconfigure.bedrock.jurrasic2; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel; @@ -29,8 +32,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Jurassic2 Chat Client. @@ -68,4 +69,4 @@ public BedrockAi21Jurassic2ChatModel jurassic2ChatModel(Ai21Jurassic2ChatBedrock .build(); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java index 183c050bcfb..6ed3f46181c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java index 59341b7f17e..b97204ea35d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.llama; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java index c58742ee42b..979e4fae472 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.llama; import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java index 639085cb756..c6b0e19ae99 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java index 4b6df741ab3..d54327a6d88 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions; @@ -45,7 +46,7 @@ public class BedrockTitanChatProperties { private BedrockTitanChatOptions options = BedrockTitanChatOptions.builder().withTemperature(0.7).build(); public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -53,7 +54,7 @@ public void setEnabled(boolean enabled) { } public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -61,7 +62,7 @@ public void setModel(String model) { } public BedrockTitanChatOptions getOptions() { - return options; + return this.options; } public void setOptions(BedrockTitanChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java index 37b63e1ab6c..96c6cfa8c18 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java index 5136d757504..b0c1e604861 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType; @@ -46,8 +47,12 @@ public class BedrockTitanEmbeddingProperties { */ private InputType inputType = InputType.IMAGE; + public static String getConfigPrefix() { + return CONFIG_PREFIX; + } + public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -55,23 +60,19 @@ public void setEnabled(boolean enabled) { } public String getModel() { - return model; + return this.model; } public void setModel(String model) { this.model = model; } - public static String getConfigPrefix() { - return CONFIG_PREFIX; + public InputType getInputType() { + return this.inputType; } public void setInputType(InputType inputType) { this.inputType = inputType; } - public InputType getInputType() { - return inputType; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java index 9e5a9185f39..b1d9f9aa923 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,10 @@ package org.springframework.ai.autoconfigure.chat.client; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClientCustomizer; import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter; @@ -33,8 +35,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Scope; -import io.micrometer.observation.ObservationRegistry; - /** * {@link EnableAutoConfiguration Auto-configuration} for {@link ChatClient}. *

    diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java index 02c55ee5257..a59653855f7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java index 102c5b745dd..91065c18904 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.client; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -58,7 +59,7 @@ public static class Observations { private boolean includeInput = false; public boolean isIncludeInput() { - return includeInput; + return this.includeInput; } public void setIncludeInput(boolean includeCompletion) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java index a635c70f6fa..9cd3c915241 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory; /** @@ -24,7 +25,7 @@ public class CommonChatMemoryProperties { private boolean initializeSchema = true; public boolean isInitializeSchema() { - return initializeSchema; + return this.initializeSchema; } public void setInitializeSchema(boolean initializeSchema) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java index b84011dfe8b..9e5cfac9448 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.cassandra; import com.datastax.oss.driver.api.core.CqlSession; + import org.springframework.ai.chat.memory.CassandraChatMemory; import org.springframework.ai.chat.memory.CassandraChatMemoryConfig; - import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java index 91d2252bcc3..fc0f45e5ce3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.cassandra; import java.time.Duration; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,7 +64,7 @@ public void setTable(String table) { } public String getAssistantColumn() { - return assistantColumn; + return this.assistantColumn; } public void setAssistantColumn(String assistantColumn) { @@ -70,7 +72,7 @@ public void setAssistantColumn(String assistantColumn) { } public String getUserColumn() { - return userColumn; + return this.userColumn; } public void setUserColumn(String userColumn) { @@ -79,7 +81,7 @@ public void setUserColumn(String userColumn) { @Nullable public Duration getTimeToLiveSeconds() { - return timeToLiveSeconds; + return this.timeToLiveSeconds; } public void setTimeToLiveSeconds(Duration timeToLiveSeconds) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java index 278ebd35c94..f439263fdad 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.observation; +import java.util.List; + import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.tracing.Tracer; import io.micrometer.tracing.otel.bridge.OtelTracer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.model.ChatModel; @@ -44,8 +48,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import java.util.List; - /** * Auto-configuration for Spring AI chat model observations. * @@ -60,6 +62,16 @@ public class ChatObservationAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(ChatObservationAutoConfiguration.class); + private static void logPromptContentWarning() { + logger.warn( + "You have enabled the inclusion of the prompt content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); + } + + private static void logCompletionWarning() { + logger.warn( + "You have enabled the inclusion of the completion content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); + } + @Bean @ConditionalOnMissingBean @ConditionalOnBean(MeterRegistry.class) @@ -141,14 +153,4 @@ public ErrorLoggingObservationHandler errorLoggingObservationHandler(Tracer trac } - private static void logPromptContentWarning() { - logger.warn( - "You have enabled the inclusion of the prompt content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - } - - private static void logCompletionWarning() { - logger.warn( - "You have enabled the inclusion of the completion content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java index 750f4911265..cd353ac0b97 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.observation; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -44,7 +45,7 @@ public class ChatObservationProperties { private boolean includeErrorLogging = false; public boolean isIncludeCompletion() { - return includeCompletion; + return this.includeCompletion; } public void setIncludeCompletion(boolean includeCompletion) { @@ -52,7 +53,7 @@ public void setIncludeCompletion(boolean includeCompletion) { } public boolean isIncludePrompt() { - return includePrompt; + return this.includePrompt; } public void setIncludePrompt(boolean includePrompt) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java index 5d159e12a7a..1a9623e32cd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java index e2c8f62aac1..ba1d61886f2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.embedding.observation; import io.micrometer.core.instrument.MeterRegistry; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler; import org.springframework.beans.factory.ObjectProvider; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java index 1d7239f591b..dfba7c66181 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java index ef28fe2b3b2..eaf74d75f86 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.huggingface; import org.springframework.ai.huggingface.HuggingfaceChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java index e64fff5845f..cf844436d56 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.huggingface; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -44,7 +45,7 @@ public class HuggingfaceChatProperties { private boolean enabled = true; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -52,7 +53,7 @@ public void setApiKey(String apiKey) { } public String getUrl() { - return url; + return this.url; } public void setUrl(String url) { @@ -60,7 +61,7 @@ public void setUrl(String url) { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java index e54333f1a60..a89a42102a3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.image.observation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.observation.ImageModelPromptContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfiguration; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java index 5663e285427..3e454ee8d20 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.image.observation; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -34,7 +35,7 @@ public class ImageObservationProperties { private boolean includePrompt = false; public boolean isIncludePrompt() { - return includePrompt; + return this.includePrompt; } public void setIncludePrompt(boolean includePrompt) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java index f95f9e63c69..019b02af129 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java index e4c549140f0..6b7c8307c6c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -40,8 +43,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java index 5ca297949a1..df7b1f0fe78 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.springframework.ai.minimax.MiniMaxChatOptions; @@ -44,7 +45,7 @@ public class MiniMaxChatProperties extends MiniMaxParentProperties { .build(); public MiniMaxChatOptions getOptions() { - return options; + return this.options; } public void setOptions(MiniMaxChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java index 1019e849949..59d5ff0cadf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java index bfdb49174f6..cb40233315c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.springframework.ai.document.MetadataMode; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java index 1f8f9f6b722..34b98c0c122 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; /** @@ -25,7 +26,7 @@ class MiniMaxParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -33,7 +34,7 @@ public void setApiKey(String apiKey) { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java index be572d81919..95241a56917 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -42,8 +45,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Ricken Bazolo * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java index 9e46cc8307c..af39bc37454 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.springframework.ai.mistralai.MistralAiChatOptions; @@ -39,10 +40,6 @@ public class MistralAiChatProperties extends MistralAiParentProperties { private static final Boolean IS_ENABLED = false; - public MistralAiChatProperties() { - super.setBaseUrl(MistralAiCommonProperties.DEFAULT_BASE_URL); - } - /** * Enable OpenAI chat model. */ @@ -56,6 +53,10 @@ public MistralAiChatProperties() { .withTopP(DEFAULT_TOP_P) .build(); + public MistralAiChatProperties() { + super.setBaseUrl(MistralAiCommonProperties.DEFAULT_BASE_URL); + } + public MistralAiChatOptions getOptions() { return this.options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java index 0bc7132be21..54023eb6362 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java index 450ac479ab3..633f3609739 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.springframework.ai.document.MetadataMode; @@ -34,13 +35,13 @@ public class MistralAiEmbeddingProperties extends MistralAiParentProperties { public static final String DEFAULT_ENCODING_FORMAT = "float"; + public MetadataMode metadataMode = MetadataMode.EMBED; + /** * Enable MistralAI embedding model. */ private boolean enabled = true; - public MetadataMode metadataMode = MetadataMode.EMBED; - @NestedConfigurationProperty private MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder() .withModel(DEFAULT_EMBEDDING_MODEL) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java index 31c632af80a..f2d398a5847 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; /** diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java index 3bd4223f416..cc3877fdd81 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; @@ -38,8 +41,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java index 91918ff00c2..299648fe176 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import org.springframework.ai.moonshot.MoonshotChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java index 07525a3113a..953f4137930 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java index 54f3f5f3b54..ed27cfba4a7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; /** diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java index 6de993705da..d64c88f9ac2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import java.nio.file.Paths; @@ -26,26 +27,9 @@ @ConfigurationProperties(OCIConnectionProperties.CONFIG_PREFIX) public class OCIConnectionProperties { - private static final String DEFAULT_PROFILE = "DEFAULT"; - public static final String CONFIG_PREFIX = "spring.ai.oci.genai"; - public enum AuthenticationType { - - FILE("file"), INSTANCE_PRINCIPAL("instance-principal"), WORKLOAD_IDENTITY("workload-identity"), - SIMPLE("simple"); - - private final String authType; - - AuthenticationType(String authType) { - this.authType = authType; - } - - public String getAuthType() { - return this.authType; - } - - } + private static final String DEFAULT_PROFILE = "DEFAULT"; private AuthenticationType authenticationType = AuthenticationType.FILE; @@ -68,7 +52,7 @@ public String getAuthType() { private String endpoint; public String getRegion() { - return region; + return this.region; } public void setRegion(String region) { @@ -76,7 +60,7 @@ public void setRegion(String region) { } public String getPassPhrase() { - return passPhrase; + return this.passPhrase; } public void setPassPhrase(String passPhrase) { @@ -84,7 +68,7 @@ public void setPassPhrase(String passPhrase) { } public String getPrivateKey() { - return privateKey; + return this.privateKey; } public void setPrivateKey(String privateKey) { @@ -92,7 +76,7 @@ public void setPrivateKey(String privateKey) { } public String getFingerprint() { - return fingerprint; + return this.fingerprint; } public void setFingerprint(String fingerprint) { @@ -100,7 +84,7 @@ public void setFingerprint(String fingerprint) { } public String getUserId() { - return userId; + return this.userId; } public void setUserId(String userId) { @@ -108,7 +92,7 @@ public void setUserId(String userId) { } public String getTenantId() { - return tenantId; + return this.tenantId; } public void setTenantId(String tenantId) { @@ -116,7 +100,7 @@ public void setTenantId(String tenantId) { } public String getFile() { - return file; + return this.file; } public void setFile(String file) { @@ -124,7 +108,7 @@ public void setFile(String file) { } public String getProfile() { - return StringUtils.hasText(profile) ? profile : DEFAULT_PROFILE; + return StringUtils.hasText(this.profile) ? this.profile : DEFAULT_PROFILE; } public void setProfile(String profile) { @@ -132,7 +116,7 @@ public void setProfile(String profile) { } public AuthenticationType getAuthenticationType() { - return authenticationType; + return this.authenticationType; } public void setAuthenticationType(AuthenticationType authenticationType) { @@ -140,11 +124,28 @@ public void setAuthenticationType(AuthenticationType authenticationType) { } public String getEndpoint() { - return endpoint; + return this.endpoint; } public void setEndpoint(String endpoint) { this.endpoint = endpoint; } + public enum AuthenticationType { + + FILE("file"), INSTANCE_PRINCIPAL("instance-principal"), WORKLOAD_IDENTITY("workload-identity"), + SIMPLE("simple"); + + private final String authType; + + AuthenticationType(String authType) { + this.authType = authType; + } + + public String getAuthType() { + return this.authType; + } + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java index cc2ef2b5f11..034836a253c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; + import org.springframework.ai.oci.OCIEmbeddingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -39,15 +41,15 @@ public class OCIEmbeddingModelProperties { public OCIEmbeddingOptions getEmbeddingOptions() { return OCIEmbeddingOptions.builder() - .withCompartment(compartment) - .withModel(model) - .withServingMode(servingMode.getMode()) - .withTruncate(truncate) + .withCompartment(this.compartment) + .withModel(this.model) + .withServingMode(this.servingMode.getMode()) + .withTruncate(this.truncate) .build(); } public ServingMode getServingMode() { - return servingMode; + return this.servingMode; } public void setServingMode(ServingMode servingMode) { @@ -55,7 +57,7 @@ public void setServingMode(ServingMode servingMode) { } public String getCompartment() { - return compartment; + return this.compartment; } public void setCompartment(String compartment) { @@ -63,7 +65,7 @@ public void setCompartment(String compartment) { } public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -71,7 +73,7 @@ public void setModel(String model) { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -79,7 +81,7 @@ public void setEnabled(boolean enabled) { } public EmbedTextDetails.Truncate getTruncate() { - return truncate; + return this.truncate; } public void setTruncate(EmbedTextDetails.Truncate truncate) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java index 9e8a64059f5..681ee71f3c3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import java.io.IOException; @@ -27,6 +28,7 @@ import com.oracle.bmc.auth.okeworkloadidentity.OkeWorkloadIdentityAuthenticationDetailsProvider; import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; import com.oracle.bmc.retrier.RetryConfiguration; + import org.springframework.ai.oci.OCIEmbeddingModel; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -44,6 +46,23 @@ @EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class }) public class OCIGenAiAutoConfiguration { + private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties) + throws IOException { + return switch (properties.getAuthenticationType()) { + case FILE -> new ConfigFileAuthenticationDetailsProvider(properties.getFile(), properties.getProfile()); + case INSTANCE_PRINCIPAL -> InstancePrincipalsAuthenticationDetailsProvider.builder().build(); + case WORKLOAD_IDENTITY -> OkeWorkloadIdentityAuthenticationDetailsProvider.builder().build(); + case SIMPLE -> SimpleAuthenticationDetailsProvider.builder() + .userId(properties.getUserId()) + .tenantId(properties.getTenantId()) + .fingerprint(properties.getFingerprint()) + .privateKeySupplier(new SimplePrivateKeySupplier(properties.getPrivateKey())) + .passPhrase(properties.getPassPhrase()) + .region(Region.valueOf(properties.getRegion())) + .build(); + }; + } + @ConditionalOnMissingBean @Bean public GenerativeAiInferenceClient generativeAiInferenceClient(OCIConnectionProperties properties) @@ -70,21 +89,4 @@ public OCIEmbeddingModel ociEmbeddingModel(GenerativeAiInferenceClient generativ return new OCIEmbeddingModel(generativeAiClient, properties.getEmbeddingOptions()); } - private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties) - throws IOException { - return switch (properties.getAuthenticationType()) { - case FILE -> new ConfigFileAuthenticationDetailsProvider(properties.getFile(), properties.getProfile()); - case INSTANCE_PRINCIPAL -> InstancePrincipalsAuthenticationDetailsProvider.builder().build(); - case WORKLOAD_IDENTITY -> OkeWorkloadIdentityAuthenticationDetailsProvider.builder().build(); - case SIMPLE -> SimpleAuthenticationDetailsProvider.builder() - .userId(properties.getUserId()) - .tenantId(properties.getTenantId()) - .fingerprint(properties.getFingerprint()) - .privateKeySupplier(new SimplePrivateKeySupplier(properties.getPrivateKey())) - .passPhrase(properties.getPassPhrase()) - .region(Region.valueOf(properties.getRegion())) - .build(); - }; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java index 7cb2299a2c8..291a056b44c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; /** @@ -29,7 +30,7 @@ public enum ServingMode { } public String getMode() { - return mode; + return this.mode; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java index 453c237266c..9b86be15097 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; @@ -40,8 +43,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import io.micrometer.observation.ObservationRegistry; - /** * {@link AutoConfiguration Auto-configuration} for Ollama Chat Client. * @@ -124,6 +125,14 @@ public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbe return embeddingModel; } + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails { private final OllamaConnectionProperties properties; @@ -139,12 +148,4 @@ public String getBaseUrl() { } - @Bean - @ConditionalOnMissingBean - public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { - FunctionCallbackContext manager = new FunctionCallbackContext(); - manager.setApplicationContext(context); - return manager; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java index c106c8c359d..ef60a94dc0a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.ai.ollama.api.OllamaModel; @@ -56,12 +57,12 @@ public OllamaOptions getOptions() { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java index 6981097c3e8..9e392486f30 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java index 160849ed89e..46f127e1310 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -34,7 +35,7 @@ public class OllamaConnectionProperties { private String baseUrl = "http://localhost:11434"; public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java index 9b21a92d5b6..b159fe7be27 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.ai.ollama.api.OllamaModel; @@ -56,12 +57,12 @@ public OllamaOptions getOptions() { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java index b884404be55..54c764f274c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama; -import org.springframework.ai.ollama.management.PullModelStrategy; -import org.springframework.boot.context.properties.ConfigurationProperties; +package org.springframework.ai.autoconfigure.ollama; import java.time.Duration; import java.util.List; +import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.boot.context.properties.ConfigurationProperties; + /** * Ollama initialization configuration properties. * @@ -32,11 +33,6 @@ public class OllamaInitializationProperties { public static final String CONFIG_PREFIX = "spring.ai.ollama.init"; - /** - * Whether to pull models at startup-time and how. - */ - private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER; - /** * Chat models initialization settings. */ @@ -47,6 +43,11 @@ public class OllamaInitializationProperties { */ private final ModelTypeInit embedding = new ModelTypeInit(); + /** + * Whether to pull models at startup-time and how. + */ + private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER; + /** * How long to wait for a model to be pulled. */ @@ -58,7 +59,7 @@ public class OllamaInitializationProperties { private int maxRetries = 0; public PullModelStrategy getPullModelStrategy() { - return pullModelStrategy; + return this.pullModelStrategy; } public void setPullModelStrategy(PullModelStrategy pullModelStrategy) { @@ -66,15 +67,15 @@ public void setPullModelStrategy(PullModelStrategy pullModelStrategy) { } public ModelTypeInit getChat() { - return chat; + return this.chat; } public ModelTypeInit getEmbedding() { - return embedding; + return this.embedding; } public Duration getTimeout() { - return timeout; + return this.timeout; } public void setTimeout(Duration timeout) { @@ -82,7 +83,7 @@ public void setTimeout(Duration timeout) { } public int getMaxRetries() { - return maxRetries; + return this.maxRetries; } public void setMaxRetries(int maxRetries) { @@ -103,7 +104,7 @@ public static class ModelTypeInit { private List additionalModels = List.of(); public boolean isInclude() { - return include; + return this.include; } public void setInclude(boolean include) { @@ -111,7 +112,7 @@ public void setInclude(boolean include) { } public List getAdditionalModels() { - return additionalModels; + return this.additionalModels; } public void setAdditionalModels(List additionalModels) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java index 8d583c20576..f7038e67978 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -57,7 +57,7 @@ public class OpenAiAudioSpeechProperties extends OpenAiParentProperties { .build(); public OpenAiAudioSpeechOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiAudioSpeechOptions options) { @@ -65,7 +65,7 @@ public void setOptions(OpenAiAudioSpeechOptions options) { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java index e8546e08906..277fed64859 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; @@ -44,7 +45,7 @@ public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties { .build(); public OpenAiAudioTranscriptionOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiAudioTranscriptionOptions options) { @@ -52,7 +53,7 @@ public void setOptions(OpenAiAudioTranscriptionOptions options) { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index d594460c5a2..3436ca32696 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import java.util.HashMap; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; import org.jetbrains.annotations.NotNull; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -56,8 +59,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Stefan Vassilev @@ -73,6 +74,36 @@ WebClientAutoConfiguration.class }) public class OpenAiAutoConfiguration { + private static @NotNull ResolvedConnectionProperties resolveConnectionProperties( + OpenAiParentProperties commonProperties, OpenAiParentProperties modelProperties, String modelType) { + + String baseUrl = StringUtils.hasText(modelProperties.getBaseUrl()) ? modelProperties.getBaseUrl() + : commonProperties.getBaseUrl(); + String apiKey = StringUtils.hasText(modelProperties.getApiKey()) ? modelProperties.getApiKey() + : commonProperties.getApiKey(); + String projectId = StringUtils.hasText(modelProperties.getProjectId()) ? modelProperties.getProjectId() + : commonProperties.getProjectId(); + String organizationId = StringUtils.hasText(modelProperties.getOrganizationId()) + ? modelProperties.getOrganizationId() : commonProperties.getOrganizationId(); + + Map> connectionHeaders = new HashMap<>(); + if (StringUtils.hasText(projectId)) { + connectionHeaders.put("OpenAI-Project", List.of(projectId)); + } + if (StringUtils.hasText(organizationId)) { + connectionHeaders.put("OpenAI-Organization", List.of(organizationId)); + } + + Assert.hasText(baseUrl, + "OpenAI base URL must be set. Use the connection property: spring.ai.openai.base-url or spring.ai.openai." + + modelType + ".base-url property."); + Assert.hasText(apiKey, + "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." + + modelType + ".api-key property."); + + return new ResolvedConnectionProperties(baseUrl, apiKey, CollectionUtils.toMultiValueMap(connectionHeaders)); + } + @Bean @ConditionalOnMissingBean @ConditionalOnProperty(prefix = OpenAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", @@ -229,37 +260,8 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex return manager; } - private static @NotNull ResolvedConnectionProperties resolveConnectionProperties( - OpenAiParentProperties commonProperties, OpenAiParentProperties modelProperties, String modelType) { - - String baseUrl = StringUtils.hasText(modelProperties.getBaseUrl()) ? modelProperties.getBaseUrl() - : commonProperties.getBaseUrl(); - String apiKey = StringUtils.hasText(modelProperties.getApiKey()) ? modelProperties.getApiKey() - : commonProperties.getApiKey(); - String projectId = StringUtils.hasText(modelProperties.getProjectId()) ? modelProperties.getProjectId() - : commonProperties.getProjectId(); - String organizationId = StringUtils.hasText(modelProperties.getOrganizationId()) - ? modelProperties.getOrganizationId() : commonProperties.getOrganizationId(); - - Map> connectionHeaders = new HashMap<>(); - if (StringUtils.hasText(projectId)) { - connectionHeaders.put("OpenAI-Project", List.of(projectId)); - } - if (StringUtils.hasText(organizationId)) { - connectionHeaders.put("OpenAI-Organization", List.of(organizationId)); - } - - Assert.hasText(baseUrl, - "OpenAI base URL must be set. Use the connection property: spring.ai.openai.base-url or spring.ai.openai." - + modelType + ".base-url property."); - Assert.hasText(apiKey, - "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." - + modelType + ".api-key property."); - - return new ResolvedConnectionProperties(baseUrl, apiKey, CollectionUtils.toMultiValueMap(connectionHeaders)); - } - private record ResolvedConnectionProperties(String baseUrl, String apiKey, MultiValueMap headers) { + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index e2014de6b81..007542e53b9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.openai.OpenAiChatOptions; @@ -26,10 +27,10 @@ public class OpenAiChatProperties extends OpenAiParentProperties { public static final String DEFAULT_CHAT_MODEL = "gpt-4o"; - private static final Double DEFAULT_TEMPERATURE = 0.7; - public static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions"; + private static final Double DEFAULT_TEMPERATURE = 0.7; + /** * Enable OpenAI chat model. */ @@ -44,7 +45,7 @@ public class OpenAiChatProperties extends OpenAiParentProperties { .build(); public OpenAiChatOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiChatOptions options) { @@ -60,7 +61,7 @@ public void setEnabled(boolean enabled) { } public String getCompletionsPath() { - return completionsPath; + return this.completionsPath; } public void setCompletionsPath(String completionsPath) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java index b065deb53a5..e6c6f582d1b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java index 008a3c18d8c..7a0e5286fdb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.document.MetadataMode; @@ -68,7 +69,7 @@ public void setEnabled(boolean enabled) { } public String getEmbeddingsPath() { - return embeddingsPath; + return this.embeddingsPath; } public void setEmbeddingsPath(String embeddingsPath) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java index 06fb24bf6a8..7e14567ba39 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.openai.OpenAiImageOptions; @@ -45,7 +46,7 @@ public class OpenAiImageProperties extends OpenAiParentProperties { private OpenAiImageOptions options = OpenAiImageOptions.builder().withModel(DEFAULT_IMAGE_MODEL).build(); public OpenAiImageOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiImageOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java index d468f591c51..d9e709862ef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ public class OpenAiModerationProperties extends OpenAiParentProperties { private OpenAiModerationOptions options = OpenAiModerationOptions.builder().build(); public OpenAiModerationOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiModerationOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java index 79aa3d833c1..7516ba84460 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; /** @@ -32,7 +33,7 @@ class OpenAiParentProperties { private String organizationId; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -40,7 +41,7 @@ public void setApiKey(String apiKey) { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java index ca30501b5f2..b2dc3da2f93 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import org.springframework.ai.postgresml.PostgresMlEmbeddingModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java index 53dba7f9311..d67f944055a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import java.util.Map; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java index 5a2efccae0f..7c7fe359b3d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -40,8 +43,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java index cd0edcd3d70..9208cd53cec 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.qianfan.QianFanChatOptions; @@ -44,7 +45,7 @@ public class QianFanChatProperties extends QianFanParentProperties { .build(); public QianFanChatOptions getOptions() { - return options; + return this.options; } public void setOptions(QianFanChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java index ff07a629101..90cb8c7a22f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.qianfan.api.QianFanConstants; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java index 2a235110097..97091f0b630 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.document.MetadataMode; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java index 3d81043cad7..5946747c8e2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.qianfan.QianFanImageOptions; @@ -44,7 +45,7 @@ public class QianFanImageProperties extends QianFanParentProperties { private QianFanImageOptions options = QianFanImageOptions.builder().withModel(DEFAULT_IMAGE_MODEL).build(); public QianFanImageOptions getOptions() { - return options; + return this.options; } public void setOptions(QianFanImageOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java index 109cc279bdc..543bec0c6fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; /** @@ -27,7 +28,7 @@ class QianFanParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -35,7 +36,7 @@ public void setApiKey(String apiKey) { } public String getSecretKey() { - return secretKey; + return this.secretKey; } public void setSecretKey(String secretKey) { @@ -43,7 +44,7 @@ public void setSecretKey(String secretKey) { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java index 295ce7ca449..0941200366d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import java.io.IOException; @@ -57,11 +58,14 @@ public RetryTemplate retryTemplate(SpringAiRetryProperties properties) { .exponentialBackoff(properties.getBackoff().getInitialInterval(), properties.getBackoff().getMultiplier(), properties.getBackoff().getMaxInterval()) .withListener(new RetryListener() { + @Override public void onError(RetryContext context, RetryCallback callback, Throwable throwable) { logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - }; + } + + ; }) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java index 8b04b81e2af..69f651794ba 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import java.time.Duration; @@ -59,6 +60,42 @@ public class SpringAiRetryProperties { */ private List onHttpCodes = new ArrayList<>(); + public int getMaxAttempts() { + return this.maxAttempts; + } + + public void setMaxAttempts(int maxAttempts) { + this.maxAttempts = maxAttempts; + } + + public Backoff getBackoff() { + return this.backoff; + } + + public List getExcludeOnHttpCodes() { + return this.excludeOnHttpCodes; + } + + public void setExcludeOnHttpCodes(List onHttpCodes) { + this.excludeOnHttpCodes = onHttpCodes; + } + + public boolean isOnClientErrors() { + return this.onClientErrors; + } + + public void setOnClientErrors(boolean onClientErrors) { + this.onClientErrors = onClientErrors; + } + + public List getOnHttpCodes() { + return this.onHttpCodes; + } + + public void setOnHttpCodes(List onHttpCodes) { + this.onHttpCodes = onHttpCodes; + } + /** * Exponential Backoff properties. */ @@ -80,7 +117,7 @@ public static class Backoff { private Duration maxInterval = Duration.ofMillis(3 * 60000); public Duration getInitialInterval() { - return initialInterval; + return this.initialInterval; } public void setInitialInterval(Duration initialInterval) { @@ -88,7 +125,7 @@ public void setInitialInterval(Duration initialInterval) { } public int getMultiplier() { - return multiplier; + return this.multiplier; } public void setMultiplier(int multiplier) { @@ -96,7 +133,7 @@ public void setMultiplier(int multiplier) { } public Duration getMaxInterval() { - return maxInterval; + return this.maxInterval; } public void setMaxInterval(Duration maxInterval) { @@ -105,40 +142,4 @@ public void setMaxInterval(Duration maxInterval) { } - public int getMaxAttempts() { - return this.maxAttempts; - } - - public void setMaxAttempts(int maxAttempts) { - this.maxAttempts = maxAttempts; - } - - public Backoff getBackoff() { - return this.backoff; - } - - public List getExcludeOnHttpCodes() { - return this.excludeOnHttpCodes; - } - - public void setExcludeOnHttpCodes(List onHttpCodes) { - this.excludeOnHttpCodes = onHttpCodes; - } - - public boolean isOnClientErrors() { - return this.onClientErrors; - } - - public void setOnClientErrors(boolean onClientErrors) { - this.onClientErrors = onClientErrors; - } - - public List getOnHttpCodes() { - return this.onHttpCodes; - } - - public void setOnHttpCodes(List onHttpCodes) { - this.onHttpCodes = onHttpCodes; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java index 1cf0d557171..e39d36e7f2d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.springframework.ai.stabilityai.api.StabilityAiApi; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java index 0a594ff0bc1..cf5f66cf6dc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.springframework.ai.stabilityai.StabilityAiImageModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java index d307750df66..9af35a1f23f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; @@ -36,9 +37,10 @@ public class StabilityAiImageProperties extends StabilityAiParentProperties { @NestedConfigurationProperty private StabilityAiImageOptions options = StabilityAiImageOptions.builder().build(); // stable-diffusion-v1-6 - // is - // default - // model + + // is + // default + // model public StabilityAiImageOptions getOptions() { return this.options; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java index f8b62cd8ea2..b62d9e5e312 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; /** @@ -28,7 +29,7 @@ class StabilityAiParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -36,7 +37,7 @@ public void setApiKey(String apiKey) { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java index 583d568631c..482e733f93b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.transformers; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.onnxruntime.OrtSession; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.beans.factory.ObjectProvider; @@ -25,10 +30,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.onnxruntime.OrtSession; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java index 2ffc590bedd..67e0922b65e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.transformers; import java.io.File; @@ -42,11 +43,33 @@ public class TransformersEmbeddingModelProperties { "spring-ai-onnx-generative") .getAbsolutePath(); + @NestedConfigurationProperty + private final Tokenizer tokenizer = new Tokenizer(); + + /** + * Controls caching of remote, large resources to local file system. + */ + @NestedConfigurationProperty + private final Cache cache = new Cache(); + + @NestedConfigurationProperty + private final Onnx onnx = new Onnx(); + /** * Enable the Transformer Embedding model. */ private boolean enabled = true; + /** + * Specifies what parts of the {@link Document}'s content and metadata will be used + * for computing the embeddings. Applicable for the + * {@link TransformersEmbeddingModel#embed(Document)} method only. Has no effect on + * the {@link TransformersEmbeddingModel#embed(String)} or + * {@link TransformersEmbeddingModel#embed(List)}. Defaults to + * {@link MetadataMode#NONE}. + */ + private MetadataMode metadataMode = MetadataMode.NONE; + public boolean isEnabled() { return this.enabled; } @@ -55,6 +78,26 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } + public Cache getCache() { + return this.cache; + } + + public Onnx getOnnx() { + return this.onnx; + } + + public Tokenizer getTokenizer() { + return this.tokenizer; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + /** * Configurations for the {@link HuggingFaceTokenizer} used to convert sentences into * tokens. @@ -93,9 +136,6 @@ public void setOptions(Map options) { } - @NestedConfigurationProperty - private final Tokenizer tokenizer = new Tokenizer(); - public static class Cache { /** @@ -128,16 +168,6 @@ public void setDirectory(String directory) { } - /** - * Controls caching of remote, large resources to local file system. - */ - @NestedConfigurationProperty - private final Cache cache = new Cache(); - - public Cache getCache() { - return this.cache; - } - public static class Onnx { /** @@ -186,33 +216,4 @@ public void setModelOutputName(String modelOutputName) { } - @NestedConfigurationProperty - private final Onnx onnx = new Onnx(); - - public Onnx getOnnx() { - return this.onnx; - } - - /** - * Specifies what parts of the {@link Document}'s content and metadata will be used - * for computing the embeddings. Applicable for the - * {@link TransformersEmbeddingModel#embed(Document)} method only. Has no effect on - * the {@link TransformersEmbeddingModel#embed(String)} or - * {@link TransformersEmbeddingModel#embed(List)}. Defaults to - * {@link MetadataMode#NONE}. - */ - private MetadataMode metadataMode = MetadataMode.NONE; - - public Tokenizer getTokenizer() { - return this.tokenizer; - } - - public MetadataMode getMetadataMode() { - return this.metadataMode; - } - - public void setMetadataMode(MetadataMode metadataMode) { - this.metadataMode = metadataMode; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java index 5fd20bb55c9..db3d6e5b924 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore; /** @@ -30,7 +31,7 @@ public class CommonVectorStoreProperties { private boolean initializeSchema = false; public boolean isInitializeSchema() { - return initializeSchema; + return this.initializeSchema; } public void setInitializeSchema(boolean initializeSchema) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java index f38bda9dd75..1e0d6bfeb4a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,14 @@ package org.springframework.ai.autoconfigure.vectorstore.azure; +import java.util.List; + import com.azure.core.credential.AzureKeyCredential; import com.azure.core.util.ClientOptions; import com.azure.search.documents.indexes.SearchIndexClient; import com.azure.search.documents.indexes.SearchIndexClientBuilder; - import io.micrometer.observation.ObservationRegistry; -import java.util.List; - import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java index 6de3a57305b..661807c7b14 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.azure; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -38,7 +39,7 @@ public class AzureVectorStoreProperties extends CommonVectorStoreProperties { private double defaultSimilarityThreshold = -1; public String getUrl() { - return url; + return this.url; } public void setUrl(String endpointUrl) { @@ -46,7 +47,7 @@ public void setUrl(String endpointUrl) { } public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -54,7 +55,7 @@ public void setApiKey(String apiKey) { } public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -62,7 +63,7 @@ public void setIndexName(String indexName) { } public int getDefaultTopK() { - return defaultTopK; + return this.defaultTopK; } public void setDefaultTopK(int defaultTopK) { @@ -70,7 +71,7 @@ public void setDefaultTopK(int defaultTopK) { } public double getDefaultSimilarityThreshold() { - return defaultSimilarityThreshold; + return this.defaultSimilarityThreshold; } public void setDefaultSimilarityThreshold(double defaultSimilarityThreshold) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index 0431133b985..f500e1bdf6a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -20,7 +20,6 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; - import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.embedding.BatchingStrategy; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java index f84a2f0e4a6..18be88e5ba3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.cassandra; import com.google.api.client.util.Preconditions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java index b9d84ec00a5..4f651278a62 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -36,7 +37,7 @@ public class ChromaApiProperties { private String password; public String getHost() { - return host; + return this.host; } public void setHost(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java index 465086d34c4..58966d3369a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java index 441cb04add8..0cf7aba5a87 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.chroma; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -31,10 +34,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java index 1768ed81d92..42d4edca047 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -31,7 +32,7 @@ public class ChromaVectorStoreProperties extends CommonVectorStoreProperties { private String collectionName = ChromaVectorStore.DEFAULT_COLLECTION_NAME; public String getCollectionName() { - return collectionName; + return this.collectionName; } public void setCollectionName(String collectionName) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java index 8fdbb1ad4cd..dd77e5341e1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,7 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.cosmosdb; +import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosClientBuilder; +import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import com.azure.cosmos.CosmosAsyncClient; -import io.micrometer.observation.ObservationRegistry; /** * @author Theo van Kraay diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java index d7d06ac25b8..ac716cbd7d2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index 3ca7a399bb3..3978f14499a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; +import io.micrometer.observation.ObservationRegistry; import org.elasticsearch.client.RestClient; import org.springframework.ai.embedding.BatchingStrategy; @@ -33,8 +34,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import io.micrometer.observation.ObservationRegistry; - /** * @author Eddú Meléndez * @author Wei Jiang diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java index ba0100bb3fd..336a8f57ebb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -52,7 +53,7 @@ public void setIndexName(String indexName) { } public Integer getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(Integer dimensions) { @@ -60,7 +61,7 @@ public void setDimensions(Integer dimensions) { } public SimilarityFunction getSimilarity() { - return similarity; + return this.similarity; } public void setSimilarity(SimilarityFunction similarity) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java index 32a3a7486af..2013f9f60a2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.gemfire; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java index f9ed2f8ce6e..32e4fada4c6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -29,8 +31,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geet Rawat * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java index 2e7e89cfe62..5e650dc923a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java @@ -94,7 +94,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { private boolean sslEnabled = GemFireVectorStore.GemFireVectorStoreConfig.DEFAULT_SSL_ENABLED; public int getBeamWidth() { - return beamWidth; + return this.beamWidth; } public void setBeamWidth(int beamWidth) { @@ -102,7 +102,7 @@ public void setBeamWidth(int beamWidth) { } public int getPort() { - return port; + return this.port; } public void setPort(int port) { @@ -110,7 +110,7 @@ public void setPort(int port) { } public String getHost() { - return host; + return this.host; } public void setHost(String host) { @@ -118,7 +118,7 @@ public void setHost(String host) { } public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -126,7 +126,7 @@ public void setIndexName(String indexName) { } public int getMaxConnections() { - return maxConnections; + return this.maxConnections; } public void setMaxConnections(int maxConnections) { @@ -134,7 +134,7 @@ public void setMaxConnections(int maxConnections) { } public String getVectorSimilarityFunction() { - return vectorSimilarityFunction; + return this.vectorSimilarityFunction; } public void setVectorSimilarityFunction(String vectorSimilarityFunction) { @@ -142,7 +142,7 @@ public void setVectorSimilarityFunction(String vectorSimilarityFunction) { } public String[] getFields() { - return fields; + return this.fields; } public void setFields(String[] fields) { @@ -150,7 +150,7 @@ public void setFields(String[] fields) { } public int getBuckets() { - return buckets; + return this.buckets; } public void setBuckets(int buckets) { @@ -158,7 +158,7 @@ public void setBuckets(int buckets) { } public boolean isSslEnabled() { - return sslEnabled; + return this.sslEnabled; } public void setSslEnabled(boolean sslEnabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java index 21cb9b9aa9b..a2dda1e7a6a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.HanaCloudVectorStore; import org.springframework.ai.vectorstore.HanaCloudVectorStoreConfig; @@ -31,8 +34,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Rahul Mittal * @author Christian Tzolov @@ -59,4 +60,4 @@ public HanaCloudVectorStore vectorStore(HanaVectorRepository null)); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java index 79dfbfbc1af..4c30d240700 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -31,7 +32,7 @@ public class HanaCloudVectorStoreProperties { private int topK; public String getTableName() { - return tableName; + return this.tableName; } public void setTableName(String tableName) { @@ -39,7 +40,7 @@ public void setTableName(String tableName) { } public int getTopK() { - return topK; + return this.topK; } public void setTopK(int topK) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java index b6d015630b5..ffc93daed75 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java index ecd3ebd516b..9d677aa2ffa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; import java.util.concurrent.TimeUnit; @@ -29,6 +30,11 @@ public class MilvusServiceClientProperties { public static final String CONFIG_PREFIX = "spring.ai.vectorstore.milvus.client"; + /** + * Secure the authorization for this connection, set to True to enable TLS. + */ + protected boolean secure = false; + /** * Milvus host name/address. */ @@ -62,15 +68,15 @@ public class MilvusServiceClientProperties { private long keepAliveTimeMs = 55000; /** - * The keep-alive timeout value of client channel. The timeout value must be greater - * than zero. + * Enables the keep-alive function for client channel. */ - private long keepAliveTimeoutMs = 20000; + // private boolean keepAliveWithoutCalls = false; /** - * Enables the keep-alive function for client channel. + * The keep-alive timeout value of client channel. The timeout value must be greater + * than zero. */ - // private boolean keepAliveWithoutCalls = false; + private long keepAliveTimeoutMs = 20000; /** * Deadline for how long you are willing to wait for a reply from the server. With a @@ -110,11 +116,6 @@ public class MilvusServiceClientProperties { */ private String serverName; - /** - * Secure the authorization for this connection, set to True to enable TLS. - */ - protected boolean secure = false; - /** * Idle timeout value of client channel. The timeout value must be larger than zero. */ @@ -131,7 +132,7 @@ public class MilvusServiceClientProperties { private String password = "milvus"; public String getHost() { - return host; + return this.host; } public void setHost(String host) { @@ -139,7 +140,7 @@ public void setHost(String host) { } public int getPort() { - return port; + return this.port; } public void setPort(int port) { @@ -147,7 +148,7 @@ public void setPort(int port) { } public String getUri() { - return uri; + return this.uri; } public void setUri(String uri) { @@ -155,7 +156,7 @@ public void setUri(String uri) { } public String getToken() { - return token; + return this.token; } public void setToken(String token) { @@ -163,7 +164,7 @@ public void setToken(String token) { } public long getConnectTimeoutMs() { - return connectTimeoutMs; + return this.connectTimeoutMs; } public void setConnectTimeoutMs(long connectTimeoutMs) { @@ -171,7 +172,7 @@ public void setConnectTimeoutMs(long connectTimeoutMs) { } public long getKeepAliveTimeMs() { - return keepAliveTimeMs; + return this.keepAliveTimeMs; } public void setKeepAliveTimeMs(long keepAliveTimeMs) { @@ -179,7 +180,7 @@ public void setKeepAliveTimeMs(long keepAliveTimeMs) { } public long getKeepAliveTimeoutMs() { - return keepAliveTimeoutMs; + return this.keepAliveTimeoutMs; } public void setKeepAliveTimeoutMs(long keepAliveTimeoutMs) { @@ -195,7 +196,7 @@ public void setKeepAliveTimeoutMs(long keepAliveTimeoutMs) { // } public long getRpcDeadlineMs() { - return rpcDeadlineMs; + return this.rpcDeadlineMs; } public void setRpcDeadlineMs(long rpcDeadlineMs) { @@ -203,7 +204,7 @@ public void setRpcDeadlineMs(long rpcDeadlineMs) { } public String getClientKeyPath() { - return clientKeyPath; + return this.clientKeyPath; } public void setClientKeyPath(String clientKeyPath) { @@ -211,7 +212,7 @@ public void setClientKeyPath(String clientKeyPath) { } public String getClientPemPath() { - return clientPemPath; + return this.clientPemPath; } public void setClientPemPath(String clientPemPath) { @@ -219,7 +220,7 @@ public void setClientPemPath(String clientPemPath) { } public String getCaPemPath() { - return caPemPath; + return this.caPemPath; } public void setCaPemPath(String caPemPath) { @@ -227,7 +228,7 @@ public void setCaPemPath(String caPemPath) { } public String getServerPemPath() { - return serverPemPath; + return this.serverPemPath; } public void setServerPemPath(String serverPemPath) { @@ -235,7 +236,7 @@ public void setServerPemPath(String serverPemPath) { } public String getServerName() { - return serverName; + return this.serverName; } public void setServerName(String serverName) { @@ -243,7 +244,7 @@ public void setServerName(String serverName) { } public boolean isSecure() { - return secure; + return this.secure; } public void setSecure(boolean secure) { @@ -251,7 +252,7 @@ public void setSecure(boolean secure) { } public long getIdleTimeoutMs() { - return idleTimeoutMs; + return this.idleTimeoutMs; } public void setIdleTimeoutMs(long idleTimeoutMs) { @@ -259,7 +260,7 @@ public void setIdleTimeoutMs(long idleTimeoutMs) { } public String getUsername() { - return username; + return this.username; } public void setUsername(String username) { @@ -267,7 +268,7 @@ public void setUsername(String username) { } public String getPassword() { - return password; + return this.password; } public void setPassword(String password) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java index ec789bed87f..e22d53b9950 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; +import java.util.concurrent.TimeUnit; + import io.micrometer.observation.ObservationRegistry; import io.milvus.client.MilvusServiceClient; import io.milvus.param.ConnectParam; import io.milvus.param.IndexType; import io.milvus.param.MetricType; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -34,8 +38,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import java.util.concurrent.TimeUnit; - /** * @author Christian Tzolov * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java index 2a4b828641f..9a17543b5db 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -58,44 +59,8 @@ public class MilvusVectorStoreProperties extends CommonVectorStoreProperties { */ private String indexParameters = "{\"nlist\":1024}"; - public enum MilvusMetricType { - - /** - * Invalid metric type - */ - INVALID, - /** - * Euclidean distance - */ - L2, - /** - * Inner product - */ - IP, - /** - * Cosine distance - */ - COSINE, - /** - * Hamming distance - */ - HAMMING, - /** - * Jaccard distance - */ - JACCARD; - - } - - public enum MilvusIndexType { - - INVALID, FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, DISKANN, AUTOINDEX, SCANN, GPU_IVF_FLAT, GPU_IVF_PQ, BIN_FLAT, - BIN_IVF_FLAT, TRIE, STL_SORT; - - } - public String getDatabaseName() { - return databaseName; + return this.databaseName; } public void setDatabaseName(String databaseName) { @@ -104,7 +69,7 @@ public void setDatabaseName(String databaseName) { } public String getCollectionName() { - return collectionName; + return this.collectionName; } public void setCollectionName(String collectionName) { @@ -113,7 +78,7 @@ public void setCollectionName(String collectionName) { } public int getEmbeddingDimension() { - return embeddingDimension; + return this.embeddingDimension; } public void setEmbeddingDimension(int embeddingDimension) { @@ -122,7 +87,7 @@ public void setEmbeddingDimension(int embeddingDimension) { } public MilvusIndexType getIndexType() { - return indexType; + return this.indexType; } public void setIndexType(MilvusIndexType indexType) { @@ -131,7 +96,7 @@ public void setIndexType(MilvusIndexType indexType) { } public MilvusMetricType getMetricType() { - return metricType; + return this.metricType; } public void setMetricType(MilvusMetricType metricType) { @@ -140,7 +105,7 @@ public void setMetricType(MilvusMetricType metricType) { } public String getIndexParameters() { - return indexParameters; + return this.indexParameters; } public void setIndexParameters(String indexParameters) { @@ -148,4 +113,40 @@ public void setIndexParameters(String indexParameters) { this.indexParameters = indexParameters; } + public enum MilvusMetricType { + + /** + * Invalid metric type + */ + INVALID, + /** + * Euclidean distance + */ + L2, + /** + * Inner product + */ + IP, + /** + * Cosine distance + */ + COSINE, + /** + * Hamming distance + */ + HAMMING, + /** + * Jaccard distance + */ + JACCARD; + + } + + public enum MilvusIndexType { + + INVALID, FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, DISKANN, AUTOINDEX, SCANN, GPU_IVF_FLAT, GPU_IVF_PQ, BIN_FLAT, + BIN_IVF_FLAT, TRIE, STL_SORT; + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java index f9053d0bb3b..59f5855a7c0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,10 @@ package org.springframework.ai.autoconfigure.vectorstore.mongo; +import java.util.Arrays; + +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -33,10 +37,6 @@ import org.springframework.util.MimeType; import org.springframework.util.StringUtils; -import io.micrometer.observation.ObservationRegistry; - -import java.util.Arrays; - /** * @author Eddú Meléndez * @author Christian Tzolov @@ -86,6 +86,7 @@ MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel @Bean public Converter mimeTypeToStringConverter() { return new Converter() { + @Override public String convert(MimeType source) { return source.toString(); @@ -96,6 +97,7 @@ public String convert(MimeType source) { @Bean public Converter stringToMimeTypeConverter() { return new Converter() { + @Override public MimeType convert(String source) { return MimeType.valueOf(source); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java index 683337464a0..22b6c680e3b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.mongo; +import java.util.List; + import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; import org.springframework.boot.context.properties.ConfigurationProperties; -import java.util.List; - /** * @author Eddú Meléndez * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java index 3faaa3b644a..e5be310fe57 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.vectorstore.neo4j; +import io.micrometer.observation.ObservationRegistry; import org.neo4j.driver.Driver; import org.springframework.ai.embedding.BatchingStrategy; @@ -31,8 +32,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Jingzhou Ou * @author Josh Long diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java index 5d782b8a9bb..53e75b46a9c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.neo4j; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java index c37c22389b1..b69a48e1beb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.observation; import io.micrometer.tracing.otel.bridge.OtelTracer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationFilter; import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler; @@ -46,6 +48,11 @@ public class VectorStoreObservationAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(VectorStoreObservationAutoConfiguration.class); + private static void logQueryResponseContentWarning() { + logger.warn( + "You have enabled the inclusion of the query response content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); + } + /** * The query response content is typically too big to be included in an observation as * span attributes. That's why the preferred way to store it is as span events, which @@ -84,9 +91,4 @@ VectorStoreQueryResponseObservationFilter vectorStoreQueryResponseContentObserva } - private static void logQueryResponseContentWarning() { - logger.warn( - "You have enabled the inclusion of the query response content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java index 33589abad26..423ca571934 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.observation; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java index 347dd6a3ef9..af2a6feec4f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java index 39c3b4c34a5..1b6b3d207cc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.opensearch; -import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; +package org.springframework.ai.autoconfigure.vectorstore.opensearch; import java.util.List; +import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; + public interface OpenSearchConnectionDetails extends ConnectionDetails { List getUris(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java index fe19c255205..78fc694a7b4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,11 @@ package org.springframework.ai.autoconfigure.vectorstore.opensearch; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Optional; + +import io.micrometer.observation.ObservationRegistry; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -25,6 +30,11 @@ import org.opensearch.client.transport.aws.AwsSdk2Transport; import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -40,17 +50,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import io.micrometer.observation.ObservationRegistry; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; - -import java.net.URISyntaxException; -import java.util.List; -import java.util.Optional; - @AutoConfiguration @ConditionalOnClass({ OpenSearchVectorStore.class, EmbeddingModel.class, OpenSearchClient.class }) @EnableConfigurationProperties(OpenSearchVectorStoreProperties.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java index a50c02ef655..a8b9f4e7e49 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.opensearch; +import java.util.List; + import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; import org.springframework.boot.context.properties.ConfigurationProperties; -import java.util.List; - @ConfigurationProperties(prefix = OpenSearchVectorStoreProperties.CONFIG_PREFIX) public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties { @@ -41,7 +42,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties private Aws aws = new Aws(); public List getUris() { - return uris; + return this.uris; } public void setUris(List uris) { @@ -57,7 +58,7 @@ public void setIndexName(String indexName) { } public String getUsername() { - return username; + return this.username; } public void setUsername(String username) { @@ -65,7 +66,7 @@ public void setUsername(String username) { } public String getPassword() { - return password; + return this.password; } public void setPassword(String password) { @@ -73,7 +74,7 @@ public void setPassword(String password) { } public String getMappingJson() { - return mappingJson; + return this.mappingJson; } public void setMappingJson(String mappingJson) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java index 837e278c2b8..a63c95bac4e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,8 @@ import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -32,8 +34,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; -import io.micrometer.observation.ObservationRegistry; - /** * @author Loïc Lefèvre * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java index 2d5eb2f7f4e..27fd396c4b1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.oracle; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -44,7 +45,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { private int searchAccuracy = DEFAULT_SEARCH_ACCURACY; public String getTableName() { - return tableName; + return this.tableName; } public void setTableName(String tableName) { @@ -52,7 +53,7 @@ public void setTableName(String tableName) { } public OracleVectorStore.OracleVectorStoreIndexType getIndexType() { - return indexType; + return this.indexType; } public void setIndexType(OracleVectorStore.OracleVectorStoreIndexType indexType) { @@ -60,7 +61,7 @@ public void setIndexType(OracleVectorStore.OracleVectorStoreIndexType indexType) } public OracleVectorStore.OracleVectorStoreDistanceType getDistanceType() { - return distanceType; + return this.distanceType; } public void setDistanceType(OracleVectorStore.OracleVectorStoreDistanceType distanceType) { @@ -68,7 +69,7 @@ public void setDistanceType(OracleVectorStore.OracleVectorStoreDistanceType dist } public int getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(int dimensions) { @@ -76,7 +77,7 @@ public void setDimensions(int dimensions) { } public boolean isRemoveExistingVectorStoreTable() { - return removeExistingVectorStoreTable; + return this.removeExistingVectorStoreTable; } public void setRemoveExistingVectorStoreTable(boolean removeExistingVectorStoreTable) { @@ -84,7 +85,7 @@ public void setRemoveExistingVectorStoreTable(boolean removeExistingVectorStoreT } public boolean isForcedNormalization() { - return forcedNormalization; + return this.forcedNormalization; } public void setForcedNormalization(boolean forcedNormalization) { @@ -92,7 +93,7 @@ public void setForcedNormalization(boolean forcedNormalization) { } public int getSearchAccuracy() { - return searchAccuracy; + return this.searchAccuracy; } public void setSearchAccuracy(int searchAccuracy) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index ec4d76e0748..8f9cf21b454 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,8 @@ import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -32,8 +34,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Josh Long diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java index 47a12c36d3e..d2947a5adc1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.pgvector; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -49,7 +50,7 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE; public int getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(int dimensions) { @@ -57,7 +58,7 @@ public void setDimensions(int dimensions) { } public PgIndexType getIndexType() { - return indexType; + return this.indexType; } public void setIndexType(PgIndexType createIndexMethod) { @@ -65,7 +66,7 @@ public void setIndexType(PgIndexType createIndexMethod) { } public PgDistanceType getDistanceType() { - return distanceType; + return this.distanceType; } public void setDistanceType(PgDistanceType distanceType) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java index 058b62841aa..9526c3e3c84 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.vectorstore.pinecone; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -29,8 +31,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Soby Chacko diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java index 3ba28c228c5..c73dfee077f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.pinecone; import java.time.Duration; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java index e5cf97fb7df..321d21b8eef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java index d1cb2379070..3d1914200cc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java index 880f6925da3..10438c5a953 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 92631831b60..2ffaf86a287 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.redis; +import io.micrometer.observation.ObservationRegistry; +import redis.clients.jedis.JedisPooled; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -32,9 +35,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import io.micrometer.observation.ObservationRegistry; -import redis.clients.jedis.JedisPooled; - /** * @author Christian Tzolov * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java index 4799afb8021..9a192260775 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.redis; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java index 48d6b6cc370..94f0fd102aa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.typesense; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java index 72e6f6a9fd1..bc4ab0dae91 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -39,7 +39,7 @@ public class TypesenseServiceClientProperties { private String apiKey = "xyz"; public String getProtocol() { - return protocol; + return this.protocol; } public void setProtocol(String protocol) { @@ -47,7 +47,7 @@ public void setProtocol(String protocol) { } public String getHost() { - return host; + return this.host; } public void setHost(String host) { @@ -55,7 +55,7 @@ public void setHost(String host) { } public int getPort() { - return port; + return this.port; } public void setPort(int port) { @@ -63,7 +63,7 @@ public void setPort(int port) { } public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java index de6e9a49033..14789133d4d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,15 @@ package org.springframework.ai.autoconfigure.vectorstore.typesense; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +import io.micrometer.observation.ObservationRegistry; +import org.typesense.api.Client; +import org.typesense.api.Configuration; +import org.typesense.resources.Node; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -28,15 +37,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import org.typesense.api.Client; -import org.typesense.api.Configuration; -import org.typesense.resources.Node; - -import io.micrometer.observation.ObservationRegistry; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; /** * @author Pablo Sanchidrian Herrera diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java index ddf74de4d28..22eea6396eb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -40,7 +40,7 @@ public class TypesenseVectorStoreProperties extends CommonVectorStoreProperties private int embeddingDimension = TypesenseVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE; public String getCollectionName() { - return collectionName; + return this.collectionName; } public void setCollectionName(String collectionName) { @@ -48,7 +48,7 @@ public void setCollectionName(String collectionName) { } public int getEmbeddingDimension() { - return embeddingDimension; + return this.embeddingDimension; } public void setEmbeddingDimension(int embeddingDimension) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java index 154271c6f14..5981040d3ad 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java index 9ca3899db74..f16226bd762 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java index d0793de4881..7b3a61f16b2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import java.util.Map; -import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; @@ -48,24 +48,24 @@ public class WeaviateVectorStoreProperties { private Map headers = Map.of(); + public String getScheme() { + return this.scheme; + } + public void setScheme(String scheme) { this.scheme = scheme; } - public String getScheme() { - return scheme; + public String getHost() { + return this.host; } public void setHost(String host) { this.host = host; } - public String getHost() { - return host; - } - public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -73,7 +73,7 @@ public void setApiKey(String apiKey) { } public String getObjectClass() { - return objectClass; + return this.objectClass; } public void setObjectClass(String indexName) { @@ -81,7 +81,7 @@ public void setObjectClass(String indexName) { } public ConsistentLevel getConsistencyLevel() { - return consistencyLevel; + return this.consistencyLevel; } public void setConsistencyLevel(ConsistentLevel consistencyLevel) { @@ -89,7 +89,7 @@ public void setConsistencyLevel(ConsistentLevel consistencyLevel) { } public Map getHeaders() { - return headers; + return this.headers; } public void setHeaders(Map headers) { @@ -97,7 +97,7 @@ public void setHeaders(Map headers) { } public Map getFilterField() { - return filterField; + return this.filterField; } public void setFilterField(Map filterMetadataFields) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java index b51c0e718cc..7549cb727e2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import java.io.IOException; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; @@ -34,10 +38,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.google.cloud.vertexai.VertexAI; - -import io.micrometer.observation.ObservationRegistry; - /** * Auto-configuration for Vertex AI Gemini Chat. * diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java index 0073f569046..a86462d396d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java index 6d08403f56e..a47488b92fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java index 102548521d0..26073283f07 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java index a6c83100a89..b73332e8170 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import java.io.IOException; import java.util.List; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; @@ -38,11 +43,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.cloud.vertexai.VertexAI; - -import io.micrometer.observation.ObservationRegistry; - /** * Auto-configuration for Vertex AI Gemini Chat. * diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java index 4e9572c1531..22d25ab2488 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java index ef65327b531..e47d41863f1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import java.util.List; @@ -58,18 +59,6 @@ public class VertexAiGeminiConnectionProperties { private Transport transport = Transport.GRPC; - public enum Transport { - - /** When used, the clients will send REST requests to the backing service. */ - REST, - /** - * When used, the clients will send gRPC to the backing service. This is usually - * more efficient and is the default transport. - */ - GRPC - - } - public String getProjectId() { return this.projectId; } @@ -98,6 +87,10 @@ public String getApiEndpoint() { return this.apiEndpoint; } + public void setApiEndpoint(String apiEndpoint) { + this.apiEndpoint = apiEndpoint; + } + public List getScopes() { return this.scopes; } @@ -106,10 +99,6 @@ public void setScopes(List scopes) { this.scopes = scopes; } - public void setApiEndpoint(String apiEndpoint) { - this.apiEndpoint = apiEndpoint; - } - public Transport getTransport() { return this.transport; } @@ -118,4 +107,16 @@ public void setTransport(Transport transport) { this.transport = transport; } + public enum Transport { + + /** When used, the clients will send REST requests to the backing service. */ + REST, + /** + * When used, the clients will send gRPC to the backing service. This is usually + * more efficient and is the default transport. + */ + GRPC + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java index e2ac02ed66a..e782503895b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java index 6c4ae48fbdc..49e93b8dcb4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java index 0dc079b03bf..531c8d4fb5f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java index 417f9a89b1a..5627defd6d4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.vertexai.palm2.VertexAiPaLm2ChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java index 4d7f9f43637..e5d53314bb6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.ai.watsonx.WatsonxAiChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java index 3f9dc8fe9d1..80b9abbf5e6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; +import java.util.List; + import org.springframework.ai.watsonx.WatsonxAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; -import java.util.List; - /** * Chat properties for Watsonx.AI Chat. * diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java index 0ffc3656d18..5e4fecf2133 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -42,7 +43,7 @@ public class WatsonxAiConnectionProperties { private String IAMToken; public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { @@ -50,7 +51,7 @@ public void setBaseUrl(String baseUrl) { } public String getStreamEndpoint() { - return streamEndpoint; + return this.streamEndpoint; } public void setStreamEndpoint(String streamEndpoint) { @@ -58,7 +59,7 @@ public void setStreamEndpoint(String streamEndpoint) { } public String getTextEndpoint() { - return textEndpoint; + return this.textEndpoint; } public void setTextEndpoint(String textEndpoint) { @@ -66,7 +67,7 @@ public void setTextEndpoint(String textEndpoint) { } public String getEmbeddingEndpoint() { - return embeddingEndpoint; + return this.embeddingEndpoint; } public void setEmbeddingEndpoint(String embeddingEndpoint) { @@ -74,7 +75,7 @@ public void setEmbeddingEndpoint(String embeddingEndpoint) { } public String getProjectId() { - return projectId; + return this.projectId; } public void setProjectId(String projectId) { @@ -82,7 +83,7 @@ public void setProjectId(String projectId) { } public String getIAMToken() { - return IAMToken; + return this.IAMToken; } public void setIAMToken(String IAMToken) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java index 42425a265a7..983291d21ed 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; @@ -40,12 +56,12 @@ public WatsonxAiEmbeddingOptions getOptions() { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java index 98afeaf79fe..7b89cad7733 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -42,8 +45,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java index d1179e99009..86004fec967 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; @@ -44,7 +45,7 @@ public class ZhiPuAiChatProperties extends ZhiPuAiParentProperties { .build(); public ZhiPuAiChatOptions getOptions() { - return options; + return this.options; } public void setOptions(ZhiPuAiChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java index 6d850f3d75f..798afd43a97 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java index 4e1c6ef80d0..86b8b008695 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.ai.document.MetadataMode; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java index 7463d457397..4751dbe6f8d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.ai.zhipuai.ZhiPuAiImageOptions; @@ -39,7 +40,7 @@ public class ZhiPuAiImageProperties extends ZhiPuAiParentProperties { private ZhiPuAiImageOptions options = ZhiPuAiImageOptions.builder().build(); public ZhiPuAiImageOptions getOptions() { - return options; + return this.options; } public void setOptions(ZhiPuAiImageOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java index 70d43d77092..c89102ec103 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; /** @@ -25,7 +26,7 @@ class ZhiPuAiParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -33,7 +34,7 @@ public void setApiKey(String apiKey) { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index d5b849aa3c8..ef4dd4b511c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration org.springframework.ai.autoconfigure.oci.genai.OCIGenAiAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java index f35b324c98f..d550d32b392 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.anthropic; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.anthropic; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -35,7 +36,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") public class AnthropicAutoConfigurationIT { @@ -48,7 +49,7 @@ public class AnthropicAutoConfigurationIT { @Test void call() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -58,7 +59,7 @@ void call() { @Test void callWith8KResponseContext() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.anthropic.beta-version=" + AnthropicApi.BETA_MAX_TOKENS, "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue()) .run(context -> { @@ -72,7 +73,7 @@ void callWith8KResponseContext() { @Test void stream() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java index ca9cca03f51..e2909971dd8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java index 3a0a80052af..0c4e933c1c1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.anthropic.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.anthropic.tool; import java.util.List; import java.util.function.Function; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -40,6 +40,8 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") class FunctionCallWithFunctionBeanIT { @@ -53,7 +55,7 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues( "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue()) .run(context -> { @@ -66,14 +68,14 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AnthropicChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), AnthropicChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -83,7 +85,7 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner + this.contextRunner .withPropertyValues( "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue()) .run(context -> { @@ -96,7 +98,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -121,4 +123,4 @@ public Function weather } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java index 9f3cf79c3e3..9dccd4c50c5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.anthropic.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.anthropic.tool; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -34,6 +34,8 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") public class FunctionCallWithPromptFunctionIT { @@ -45,7 +47,7 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues( "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue()) .run(context -> { @@ -64,10 +66,10 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java index 752ddb2d71d..e27e66300ee 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic.tool; import java.util.function.Function; @@ -30,14 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -65,28 +73,23 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - - @Override - public Response apply(Request request) { - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 0f4e7fe4930..c50ce1d4a58 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,25 @@ package org.springframework.ai.autoconfigure.azure; +import java.lang.reflect.Field; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.implementation.OpenAIClientImpl; -import com.azure.core.http.*; +import com.azure.core.http.HttpHeader; +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpMethod; +import com.azure.core.http.HttpPipeline; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; @@ -39,13 +52,6 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.ReflectionUtils; -import reactor.core.publisher.Flux; - -import java.lang.reflect.Field; -import java.net.URI; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -89,16 +95,16 @@ class AzureOpenAiAutoConfigurationIT { @Test void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test void httpRequestContainsUserAgentAndCustomHeaders() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", "spring.ai.azure.openai.custom-headers.fizz=buzz") .run(context -> { @@ -125,11 +131,11 @@ void httpRequestContainsUserAgentAndCustomHeaders() { @Test void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - Flux response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatModel.stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(10); @@ -147,7 +153,7 @@ void chatCompletionStreaming() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -165,7 +171,7 @@ void embedding() { @Test @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME", matches = ".+") void transcribe() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiAudioTranscriptionModel transcriptionModel = context .getBean(AzureOpenAiAudioTranscriptionModel.class); Resource audioFile = new ClassPathResource("/speech/jfk.flac"); @@ -179,17 +185,17 @@ void transcribe() { void chatActivation() { // Disable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty(); }); // The chat auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); }); // Explicitly enable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=true").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); }); } @@ -198,17 +204,17 @@ void chatActivation() { void embeddingActivation() { // Disable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty(); }); // The embedding auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); }); // Explicitly enable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=true").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); }); } @@ -217,19 +223,21 @@ void embeddingActivation() { void audioTranscriptionActivation() { // Disable the transcription auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty(); + }); // The transcription auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); }); // Explicitly enable the transcription auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=true").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=true") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); + }); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java index 581f178c046..e83c75e2327 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java index f042706aa97..19c651b3efe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure; import java.util.List; @@ -21,19 +22,19 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.azure.openai.AzureOpenAiChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; +import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -73,20 +74,20 @@ public class AzureOpenAiDirectOpenAiAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - Flux response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatModel.stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(10); @@ -104,7 +105,7 @@ public void chatCompletionStreaming() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java index dafd8f49a2a..fa2c77b1f5b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.azure.tool; import org.springframework.util.StringUtils; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java index a7e04c3515a..dda06c82429 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.azure.tool; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; +package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; import java.util.function.Function; @@ -25,6 +23,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; @@ -39,6 +38,9 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; + @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") class FunctionCallWithFunctionBeanIT { @@ -55,7 +57,8 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) .run(context -> { ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -66,14 +69,14 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -82,7 +85,8 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) .run(context -> { ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -93,7 +97,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -119,4 +123,4 @@ public Function weather } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java index 9a54bef4ea8..62071ed4168 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; @@ -25,8 +26,8 @@ import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackWrapper; @@ -54,7 +55,8 @@ public class FunctionCallWithFunctionWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) .run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -65,7 +67,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30", "10", "15"); @@ -86,4 +88,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java index 4c2b622ecb4..00a9145354f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; @@ -25,8 +26,8 @@ import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -50,7 +51,8 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) .run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -67,10 +69,10 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java index 9333522be2c..0d390e57ef0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.tool; import java.util.function.Function; @@ -30,15 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @Override + public Response apply(Request request) { - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -66,28 +73,24 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - - @Override - public Response apply(Request request) { - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java index bea58ce80e3..e3d9b2ff6eb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.bedrock; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.bedrock; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -27,10 +31,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Wei Jiang diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 4137e33ce62..cc0824cfada 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic; import java.util.List; @@ -21,19 +22,19 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -68,20 +69,21 @@ public class BedrockAnthropicChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropicChatModel anthropicChatModel = context.getBean(BedrockAnthropicChatModel.class); - ChatResponse response = anthropicChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = anthropicChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropicChatModel anthropicChatModel = context.getBean(BedrockAnthropicChatModel.class); - Flux response = anthropicChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = anthropicChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java index 3defe79b3b4..8475517bd97 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic3; import java.util.List; @@ -21,19 +22,19 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -68,20 +69,21 @@ public class BedrockAnthropic3ChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropic3ChatModel anthropicChatModel = context.getBean(BedrockAnthropic3ChatModel.class); - ChatResponse response = anthropicChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = anthropicChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropic3ChatModel anthropicChatModel = context.getBean(BedrockAnthropic3ChatModel.class); - Flux response = anthropicChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = anthropicChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index 83b487c9018..bf748291571 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import java.util.List; @@ -21,21 +22,21 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -71,20 +72,21 @@ public class BedrockCohereChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereChatModel cohereChatModel = context.getBean(BedrockCohereChatModel.class); - ChatResponse response = cohereChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = cohereChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereChatModel cohereChatModel = context.getBean(BedrockCohereChatModel.class); - Flux response = cohereChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = cohereChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java index 14d38895516..4523a19553d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingModel; @@ -25,9 +30,6 @@ import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import software.amazon.awssdk.regions.Region; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -51,7 +53,7 @@ public class BedrockCohereEmbeddingAutoConfigurationIT { @Test public void singleEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereEmbeddingModel embeddingModel = context.getBean(BedrockCohereEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); @@ -63,7 +65,7 @@ public void singleEmbedding() { @Test public void batchEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereEmbeddingModel embeddingModel = context.getBean(BedrockCohereEmbeddingModel.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java index ace30a03d17..2c2af6bda83 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,24 +16,25 @@ package org.springframework.ai.autoconfigure.bedrock.jurassic2; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration; import org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatProperties; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import software.amazon.awssdk.regions.Region; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -68,9 +69,10 @@ public class BedrockAi21Jurassic2ChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAi21Jurassic2ChatModel ai21Jurassic2ChatModel = context.getBean(BedrockAi21Jurassic2ChatModel.class); - ChatResponse response = ai21Jurassic2ChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = ai21Jurassic2ChatModel + .call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java index f1ed73b8b11..6c4fecc11fc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.llama; import java.util.List; @@ -21,19 +22,19 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -70,20 +71,21 @@ public class BedrockLlamaChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockLlamaChatModel llamaChatModel = context.getBean(BedrockLlamaChatModel.class); - ChatResponse response = llamaChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = llamaChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockLlamaChatModel llamaChatModel = context.getBean(BedrockLlamaChatModel.class); - Flux response = llamaChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = llamaChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 94a2fda1b68..78749d87324 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import java.util.List; @@ -21,19 +22,19 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -69,20 +70,20 @@ public class BedrockTitanChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockTitanChatModel chatModel = context.getBean(BedrockTitanChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockTitanChatModel chatModel = context.getBean(BedrockTitanChatModel.class); - Flux response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatModel.stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 5a5a2ad4c19..525898ac065 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import java.util.Base64; @@ -51,7 +52,7 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { @Test public void singleTextEmbedding() { - contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=TEXT").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=TEXT").run(context -> { BedrockTitanEmbeddingModel embeddingModel = context.getBean(BedrockTitanEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); @@ -63,7 +64,7 @@ public void singleTextEmbedding() { @Test public void singleImageEmbedding() { - contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=IMAGE").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=IMAGE").run(context -> { BedrockTitanEmbeddingModel embeddingModel = context.getBean(BedrockTitanEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java index 159e2cb1d71..4dd65c789bc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.client; import java.util.List; @@ -50,28 +51,28 @@ public class ChatClientAutoConfigurationIT { @Test void implicitlyEnabled() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty(); }); } @Test void explicitlyEnabled() { - contextRunner.withPropertyValues("spring.ai.chat.client.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=true").run(context -> { assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty(); }); } @Test void explicitlyDisabled() { - contextRunner.withPropertyValues("spring.ai.chat.client.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=false").run(context -> { assertThat(context.getBeansOfType(ChatClient.Builder.class)).isEmpty(); }); } @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ChatClient.Builder builder = context.getBean(ChatClient.Builder.class); assertThat(builder).isNotNull(); @@ -87,7 +88,7 @@ void generate() { @Test void testChatClientCustomizers() { - contextRunner.withUserConfiguration(Config.class).run(context -> { + this.contextRunner.withUserConfiguration(Config.class).run(context -> { ChatClient.Builder builder = context.getBean(ChatClient.Builder.class); @@ -107,6 +108,7 @@ void testChatClientCustomizers() { } record ActorsFilms(String actor, List movies) { + } @Configuration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java index 94a7ec0db88..658a758b05c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.chat.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.chat.client; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit tests for {@link ChatClientAutoConfiguration} observability support. * @@ -34,14 +36,14 @@ class ChatClientObservationAutoConfigurationTests { @Test void inputContentFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatClientInputContentObservationFilter.class); }); } @Test void inputContentFilterEnabled() { - contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true").run(context -> { assertThat(context).hasSingleBean(ChatClientInputContentObservationFilter.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java index 86df7f40cb3..abc4b6c4317 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.chat.memory.cassandra; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.chat.memory.cassandra; import java.util.List; +import com.datastax.driver.core.utils.UUIDs; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.chat.memory.CassandraChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; @@ -27,12 +32,8 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import com.datastax.driver.core.utils.UUIDs; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Mick Semb Wever @@ -53,7 +54,7 @@ class CassandraChatMemoryAutoConfigurationIT { @Test void addAndGet() { - contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) + this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java index e3df66722e2..c3b47308968 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.cassandra; import java.time.Duration; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java index 7e4f01a6943..cafd64873fc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.observation; import io.micrometer.core.instrument.composite.CompositeMeterRegistry; @@ -20,7 +21,12 @@ import io.micrometer.tracing.otel.bridge.OtelTracer; import io.opentelemetry.api.OpenTelemetry; import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.observation.*; + +import org.springframework.ai.chat.observation.ChatModelCompletionObservationFilter; +import org.springframework.ai.chat.observation.ChatModelCompletionObservationHandler; +import org.springframework.ai.chat.observation.ChatModelMeterObservationHandler; +import org.springframework.ai.chat.observation.ChatModelPromptContentObservationFilter; +import org.springframework.ai.chat.observation.ChatModelPromptContentObservationHandler; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -38,35 +44,35 @@ class ChatObservationAutoConfigurationTests { @Test void meterObservationHandlerEnabled() { - contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { + this.contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { assertThat(context).hasSingleBean(ChatModelMeterObservationHandler.class); }); } @Test void meterObservationHandlerDisabled() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelMeterObservationHandler.class); }); } @Test void promptFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationFilter.class); }); } @Test void promptHandlerDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class); }); } @Test void promptHandlerEnabled() { - contextRunner + this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.chat.observations.include-prompt=true") .run(context -> { @@ -76,28 +82,28 @@ void promptHandlerEnabled() { @Test void promptHandlerDisabled() { - contextRunner.withPropertyValues("spring.ai.chat.observations.include-prompt=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-prompt=true").run(context -> { assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class); }); } @Test void completionFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelCompletionObservationFilter.class); }); } @Test void completionHandlerDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class); }); } @Test void completionHandlerEnabled() { - contextRunner + this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.chat.observations.include-completion=true") .run(context -> { @@ -107,7 +113,7 @@ void completionHandlerEnabled() { @Test void completionHandlerDisabled() { - contextRunner.withPropertyValues("spring.ai.chat.observations.include-completion=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-completion=true").run(context -> { assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java index c479690f63f..ad19103371b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.embedding.observation; import io.micrometer.core.instrument.composite.CompositeMeterRegistry; import org.junit.jupiter.api.Test; + import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -35,14 +37,14 @@ class EmbeddingObservationAutoConfigurationTests { @Test void meterObservationHandlerEnabled() { - contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { + this.contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class); }); } @Test void meterObservationHandlerDisabled() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java index 300962345ca..a0b5c014d8d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.huggingface; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.huggingface; import java.util.List; import java.util.stream.Collectors; @@ -25,6 +24,8 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -34,7 +35,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_CHAT_URL", matches = ".+") @@ -43,7 +44,7 @@ public class HuggingfaceChatAutoConfigurationIT { private static final Log logger = LogFactory.getLog(HuggingfaceChatAutoConfigurationIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( - // @formatter:off + // @formatter:off "spring.ai.huggingface.chat.api-key=" + System.getenv("HUGGINGFACE_API_KEY"), "spring.ai.huggingface.chat.url=" + System.getenv("HUGGINGFACE_CHAT_URL")) // @formatter:on @@ -51,7 +52,7 @@ public class HuggingfaceChatAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { HuggingfaceChatModel chatModel = context.getBean(HuggingfaceChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -62,7 +63,7 @@ void generate() { @Disabled("Until streaming support is added") @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { HuggingfaceChatModel chatModel = context.getBean(HuggingfaceChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java index 0c26b992aac..b4bd232039a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.image.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.observation.ImageModelPromptContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -34,14 +36,14 @@ class ImageObservationAutoConfigurationTests { @Test void promptFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ImageModelPromptContentObservationFilter.class); }); } @Test void promptFilterEnabled() { - contextRunner.withPropertyValues("spring.ai.image.observations.include-prompt=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.image.observations.include-prompt=true").run(context -> { assertThat(context).hasSingleBean(ImageModelPromptContentObservationFilter.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java index 6d97b82dec5..167102cb8c7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -31,10 +37,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -53,7 +55,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -70,7 +72,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -79,7 +81,7 @@ void functionCallTest() { @Test void streamingFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -104,7 +106,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -112,4 +114,4 @@ void streamingFunctionCallTest() { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java index ceb00d47632..5b33d4c673b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -35,11 +42,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +62,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { // FIXME: multiple function calls may stop prematurely due to model performance @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -71,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -79,7 +81,7 @@ void functionCallTest() { response = chatModel.call(new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -88,7 +90,7 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -102,14 +104,14 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); }); } // FIXME: multiple function calls may stop prematurely due to model performance @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -128,7 +130,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -146,7 +148,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -173,4 +175,4 @@ public Function weather } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java index 612622d6906..780476ea2fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -34,10 +40,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -57,7 +59,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -67,7 +69,7 @@ void functionCallTest() { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -76,7 +78,7 @@ void functionCallTest() { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -94,7 +96,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -118,4 +120,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java index 507ed41ff8c..06e8c2b3ec7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -29,10 +35,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -51,7 +53,7 @@ public class MiniMaxAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -61,7 +63,7 @@ void generate() { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -75,7 +77,7 @@ void generateStreaming() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MiniMaxEmbeddingModel embeddingModel = context.getBean(MiniMaxEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java index f8a2f5e2a18..47131bedc32 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxEmbeddingModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java index b7f792d3aca..61a5394dba8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * Mock 3rd party weather service. * @@ -30,16 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Get the weather in location") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Get the weather in location") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java index 5244180e734..441f88450ac 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; @@ -33,7 +34,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -50,7 +51,7 @@ public class MistralAiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -60,7 +61,7 @@ void generate() { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -74,7 +75,7 @@ void generateStreaming() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MistralAiEmbeddingModel embeddingModel = context.getBean(MistralAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java index ee4cf9aa7ba..e69a711ffa0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java index 69709a91070..caa18b90b3d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.Map; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -38,11 +39,16 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".*") class PaymentStatusBeanIT { + // Assuming we have the following data + public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", + new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", + new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); + private final Logger logger = LoggerFactory.getLogger(PaymentStatusBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() @@ -53,7 +59,7 @@ class PaymentStatusBeanIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue()) .run(context -> { @@ -66,33 +72,20 @@ void functionCallTest() { .withFunction("retrievePaymentDate") .build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); }); } - // Assuming we have the following data - public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", - new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", - new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); - record StatusDate(String status, String date) { + } @Configuration static class Config { - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - public record Date(@JsonProperty(required = true, value = "date") String date) { - } - @Bean @Description("Get payment status of a transaction") public Function retrievePaymentStatus() { @@ -105,6 +98,18 @@ public Function retrievePaymentDate() { return (transaction) -> new Date(DATA.get(transaction.transactionId).date()); } + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java index 3fc46b03b8e..5a428b91b3d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.Map; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -38,7 +39,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; /** * Same test as {@link PaymentStatusBeanIT.java} but using {@link OpenAiChatModel} for @@ -49,6 +50,11 @@ @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".*") class PaymentStatusBeanOpenAiIT { + // Assuming we have the following data + public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", + new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", + new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); + private final Logger logger = LoggerFactory.getLogger(PaymentStatusBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() @@ -60,7 +66,7 @@ class PaymentStatusBeanOpenAiIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + MistralAiApi.ChatModel.SMALL.getValue()) .run(context -> { @@ -73,33 +79,20 @@ void functionCallTest() { .withFunction("retrievePaymentDate") .build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); }); } - // Assuming we have the following data - public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", - new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", - new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); - record StatusDate(String status, String date) { + } @Configuration static class Config { - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - public record Date(@JsonProperty(required = true, value = "date") String date) { - } - @Bean @Description("Get payment status of a transaction") public Function retrievePaymentStatus() { @@ -112,6 +105,18 @@ public Function retrievePaymentDate() { return (transaction) -> new Date(DATA.get(transaction.transactionId).date()); } + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java index 0cf0d18b0f6..2efdad49317 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.Map; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -36,35 +37,26 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".*") public class PaymentStatusPromptIT { - private final Logger logger = LoggerFactory.getLogger(WeatherServicePromptIT.class); - - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) - .withConfiguration(AutoConfigurations.of(MistralAiAutoConfiguration.class)); - - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String id) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - record StatusDate(String status, String date) { - } - // Assuming we have the following payment data. public static final Map DATA = Map.of(new Transaction("T1001"), new StatusDate("Paid", "2021-10-05"), new Transaction("T1002"), new StatusDate("Unpaid", "2021-10-06"), new Transaction("T1003"), new StatusDate("Paid", "2021-10-07"), new Transaction("T1004"), new StatusDate("Paid", "2021-10-05"), new Transaction("T1005"), new StatusDate("Pending", "2021-10-08")); + private final Logger logger = LoggerFactory.getLogger(WeatherServicePromptIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) + .withConfiguration(AutoConfigurations.of(MistralAiAutoConfiguration.class)); + @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.SMALL.getValue()) .run(context -> { @@ -74,6 +66,7 @@ void functionCallTest() { var promptOptions = MistralAiChatOptions.builder() .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new Function() { + public Status apply(Transaction transaction) { return new Status(DATA.get(transaction).status()); } @@ -85,11 +78,23 @@ public Status apply(Transaction transaction) { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); }); } -} \ No newline at end of file + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String id) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + record StatusDate(String status, String date) { + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java index 13750c3b0d2..f546eb8bbd1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Request; import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Response; @@ -40,9 +43,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,7 +60,7 @@ public class WeatherServicePromptIT { @Test void promptFunctionCall() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue()) .run(context -> { @@ -80,7 +81,7 @@ void promptFunctionCall() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15", "15.0"); // assertThat(response.getResult().getOutput().getContent()).contains("30.0", @@ -90,7 +91,7 @@ void promptFunctionCall() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue()) .run(context -> { @@ -108,7 +109,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15", "15.0"); }); @@ -116,17 +117,6 @@ void functionCallWithPortableFunctionCallingOptions() { public static class MyWeatherService implements Function { - // @formatter:off - public enum Unit { C, F } - - @JsonInclude(Include.NON_NULL) - public record Request( - @JsonProperty(required = true, value = "location") String location, - @JsonProperty(required = true, value = "unit") Unit unit) {} - - public record Response(double temperature, Unit unit) {} - // @formatter:on - @Override public Response apply(Request request) { if (request.location().contains("Paris")) { @@ -141,6 +131,19 @@ else if (request.location().contains("San Francisco")) { throw new IllegalArgumentException("Invalid request: " + request); } + // @formatter:off + public enum Unit { C, F } + + @JsonInclude(Include.NON_NULL) + public record Request( + @JsonProperty(required = true, value = "location") String location, + @JsonProperty(required = true, value = "unit") Unit unit) {} + // @formatter:on + + public record Response(double temperature, Unit unit) { + + } + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java index 196e04f82f8..f24c6599cb5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; +import java.util.Objects; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -27,10 +33,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -49,7 +51,7 @@ public class MoonshotAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel client = context.getBean(MoonshotChatModel.class); String response = client.call("Hello"); assertThat(response).isNotEmpty(); @@ -59,7 +61,7 @@ void generate() { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel client = context.getBean(MoonshotChatModel.class); Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); String response = Objects.requireNonNull(responseFlux.collectList().block()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java index 213ccd94516..79e7bd14265 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import org.junit.jupiter.api.Test; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.moonshot.MoonshotChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java index b31f9c28fc9..2853c4a461b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -32,10 +38,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +56,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -71,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -80,7 +82,7 @@ void functionCallTest() { @Test void streamingFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -105,7 +107,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -113,4 +115,4 @@ void streamingFunctionCallTest() { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java index e21c283c7c2..e94be42200d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -36,11 +43,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +62,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -71,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -79,7 +81,7 @@ void functionCallTest() { response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -88,7 +90,7 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -102,13 +104,13 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); }); } @Test void streamFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -127,7 +129,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -145,7 +147,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -172,4 +174,4 @@ public Function weather } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java index a5b7c877959..9de829cc77c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -35,11 +42,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -59,7 +61,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -69,7 +71,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -78,7 +80,7 @@ void functionCallTest() { @Test void streamFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -97,7 +99,7 @@ void streamFunctionCallTest() { .map(AssistantMessage::getContent) .filter(Objects::nonNull) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -121,4 +123,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java index 0bdfcbb1cb6..3d8e96ba6e4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * Mock 3rd party weather service. * @@ -30,14 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -65,28 +73,23 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java index d67a8e0e715..d23681cad1f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import java.nio.file.Paths; @@ -20,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.oci.OCIEmbeddingModel; @@ -40,8 +42,8 @@ public class OCIGenAiAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.oci.genai.authenticationType=file", - "spring.ai.oci.genai.file=" + CONFIG_FILE, - "spring.ai.oci.genai.embedding.compartment=" + COMPARTMENT_ID, + "spring.ai.oci.genai.file=" + this.CONFIG_FILE, + "spring.ai.oci.genai.embedding.compartment=" + this.COMPARTMENT_ID, "spring.ai.oci.genai.embedding.servingMode=on-demand", "spring.ai.oci.genai.embedding.model=cohere.embed-english-light-v2.0" // @formatter:on @@ -49,7 +51,7 @@ public class OCIGenAiAutoConfigurationIT { @Test void embeddings() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OCIEmbeddingModel embeddingModel = context.getBean(OCIEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse response = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java index f4323403958..3a096dc09b0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java @@ -1,21 +1,33 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.ollama; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; -import org.testcontainers.ollama.OllamaContainer; public class BaseOllamaIT { - // Toggle for running tests locally on native Ollama for a faster feedback loop. - private static final boolean useTestcontainers = true; - public static final OllamaContainer ollamaContainer; - static { - ollamaContainer = new OllamaContainer(OllamaImage.IMAGE).withReuse(true); - ollamaContainer.start(); - } + // Toggle for running tests locally on native Ollama for a faster feedback loop. + private static final boolean useTestcontainers = true; /** * Change the return value to false in order to run multiple Ollama IT tests locally @@ -41,4 +53,9 @@ public static String buildConnectionWithModel(String model) { return baseUrl; } + static { + ollamaContainer = new OllamaContainer(OllamaImage.IMAGE).withReuse(true); + ollamaContainer.start(); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java index 10b07fb4559..b596b14f05f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.ollama; import java.io.IOException; import java.util.List; @@ -24,6 +23,9 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -35,9 +37,8 @@ import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -53,11 +54,6 @@ public class OllamaChatAutoConfigurationIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() throws IOException, InterruptedException { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.baseUrl=" + baseUrl, @@ -69,22 +65,27 @@ public static void beforeAll() throws IOException, InterruptedException { private final UserMessage userMessage = new UserMessage("What's the capital of Denmark?"); + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); - ChatResponse response = chatModel.call(new Prompt(userMessage)); + ChatResponse response = chatModel.call(new Prompt(this.userMessage)); assertThat(response.getResult().getOutput().getContent()).contains("Copenhagen"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); - Flux response = chatModel.stream(new Prompt(userMessage)); + Flux response = chatModel.stream(new Prompt(this.userMessage)); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -102,7 +103,7 @@ public void chatCompletionStreaming() { @Test public void chatCompletionWithPull() { - contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") + this.contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") .withPropertyValues("spring.ai.ollama.chat.options.model=tinyllama") .run(context -> { var model = "tinyllama"; @@ -111,7 +112,7 @@ public void chatCompletionWithPull() { assertThat(modelManager.isModelAvailable(model)).isTrue(); OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); - ChatResponse response = chatModel.call(new Prompt(userMessage)); + ChatResponse response = chatModel.call(new Prompt(this.userMessage)); assertThat(response.getResult().getOutput().getContent()).contains("Copenhagen"); modelManager.deleteModel(model); }); @@ -119,17 +120,17 @@ public void chatCompletionWithPull() { @Test void chatActivation() { - contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=false").run(context -> { assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaChatModel.class)).isEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaChatModel.class)).isNotEmpty(); }); - contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=true").run(context -> { assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaChatModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java index 77e7e06a19b..14493bcb20c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 0ea701a6708..0a2521aef01 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import java.io.IOException; @@ -21,12 +22,13 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; -import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.management.OllamaModelManager; import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.ollama.OllamaEmbeddingModel; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -46,19 +48,19 @@ public class OllamaEmbeddingAutoConfigurationIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() throws IOException, InterruptedException { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.ollama.embedding.options.model=" + MODEL_NAME, "spring.ai.ollama.base-url=" + baseUrl) .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class)); + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test public void singleTextEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaEmbeddingModel embeddingModel = context.getBean(OllamaEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); @@ -70,7 +72,7 @@ public void singleTextEmbedding() { @Test public void embeddingWithPull() { - contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") + this.contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") .withPropertyValues("spring.ai.ollama.embedding.options.model=all-minilm") .run(context -> { var model = "all-minilm"; @@ -87,17 +89,17 @@ public void embeddingWithPull() { @Test void embeddingActivation() { - contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=false").run(context -> { assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaEmbeddingModel.class)).isEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=true").run(context -> { assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaEmbeddingModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index 487485a062c..bd2a8bfd2df 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java index fb9db0f36d7..ebabcc72217 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; public class OllamaImage { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java index b7bc4e408bf..53f2973cc5e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.ollama.tool; import java.util.List; import java.util.stream.Collectors; @@ -26,6 +25,9 @@ import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -38,9 +40,8 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @Testcontainers @DisabledIf("isDisabled") @@ -52,11 +53,6 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.baseUrl=" + baseUrl, @@ -66,9 +62,14 @@ public static void beforeAll() { // @formatter:on .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); + @BeforeAll + public static void beforeAll() { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -95,7 +96,7 @@ void functionCallTest() { @Disabled("Ollama API does not support streaming function calls yet") @Test void streamingFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -129,4 +130,4 @@ void streamingFunctionCallTest() { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java index 82fd7eb119d..451a970cb75 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.ollama.tool; import java.util.List; import java.util.stream.Collectors; @@ -26,6 +25,9 @@ import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -38,15 +40,13 @@ import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.ollama.OllamaChatModel; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @Testcontainers @DisabledIf("isDisabled") @@ -58,11 +58,6 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.baseUrl=" + baseUrl, @@ -73,9 +68,14 @@ public static void beforeAll() { .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)) .withUserConfiguration(Config.class); + @BeforeAll + public static void beforeAll() { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -94,7 +94,7 @@ void functionCallTest() { @Disabled("Ollama API does not support streaming function calls yet") @Test void streamFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -120,7 +120,7 @@ void streamFunctionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -156,4 +156,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java index dc780891bf7..e4a1487c2bd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama.tool; import java.util.function.Function; @@ -30,16 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 10; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 10; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 2a428ee9459..e02e4244276 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai; import java.util.Arrays; import java.util.List; @@ -25,6 +24,8 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; @@ -42,7 +43,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class OpenAiAutoConfigurationIT { @@ -55,7 +56,7 @@ public class OpenAiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -65,7 +66,7 @@ void generate() { @Test void transcribe() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiAudioTranscriptionModel transcriptionModel = context.getBean(OpenAiAudioTranscriptionModel.class); Resource audioFile = new ClassPathResource("/speech/jfk.flac"); String response = transcriptionModel.call(audioFile); @@ -76,7 +77,7 @@ void transcribe() { @Test void speech() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiAudioSpeechModel speechModel = context.getBean(OpenAiAudioSpeechModel.class); byte[] response = speechModel.call("H"); assertThat(response).isNotNull(); @@ -102,7 +103,7 @@ public boolean verifyMp3FrameHeader(byte[] audioResponse) { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -116,7 +117,7 @@ void generateStreaming() { @Test void streamingWithTokenUsage() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.stream-usage=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.stream-usage=true").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); @@ -138,7 +139,7 @@ void streamingWithTokenUsage() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiEmbeddingModel embeddingModel = context.getBean(OpenAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -155,7 +156,7 @@ void embedding() { @Test void generateImage() { - contextRunner.withPropertyValues("spring.ai.openai.image.options.size=1024x1024").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.image.options.size=1024x1024").run(context -> { OpenAiImageModel imageModel = context.getBean(OpenAiImageModel.class); ImageResponse imageResponse = imageModel.call(new ImagePrompt("forest")); assertThat(imageResponse.getResults()).hasSize(1); @@ -167,7 +168,7 @@ void generateImage() { @Test void generateImageWithModel() { // The 256x256 size is supported by dall-e-2, but not by dall-e-3. - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.image.options.model=dall-e-2", "spring.ai.openai.image.options.size=256x256") .run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index 5c38bfee6c3..3ba07792bef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,26 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.OpenAiImageModel; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link OpenAiConnectionProperties}, {@link OpenAiChatProperties} and * {@link OpenAiEmbeddingProperties}. diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java index 03a3eba1e7b..03d6af56cc8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai; import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiChatModel; @@ -28,6 +28,8 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link OpenAiChatProperties} #options#responseFormat support. * diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index cdf76ec241a..538c8456abb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.function.Function; import java.util.stream.Collectors; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.openai.OpenAiChatModel; @@ -31,6 +31,8 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackInPrompt2IT { @@ -42,7 +44,7 @@ public class FunctionCallbackInPrompt2IT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -60,7 +62,7 @@ void functionCallTest() { .call().content(); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -68,7 +70,7 @@ void functionCallTest() { @Test void functionCallTest2() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -85,7 +87,7 @@ public String apply(MockWeatherService.Request request) { }) .call().content(); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("18"); }); @@ -94,7 +96,7 @@ public String apply(MockWeatherService.Request request) { @Test void streamingFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -107,10 +109,10 @@ void streamingFunctionCallTest() { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java index e5c0c4fca6a..4de98c17762 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -37,7 +38,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackInPromptIT { @@ -50,7 +51,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), "spring.ai.openai.chat.options.temperature=0.1") .run(context -> { @@ -70,7 +71,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -79,7 +80,7 @@ void functionCallTest() { @Test void streamingFunctionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), "spring.ai.openai.chat.options.temperature=0.5") .run(context -> { @@ -107,10 +108,10 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index beb8292781e..c4b8438214a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -15,8 +15,6 @@ */ package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.Map; import java.util.function.BiFunction; @@ -27,6 +25,8 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; @@ -46,7 +46,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") class FunctionCallbackWithPlainFunctionBeanIT { @@ -269,4 +269,4 @@ public MockWeatherService.Response apply(MockWeatherService.Request request, Too } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java index aaf84d98aeb..0056fc20ba2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.stream.Collectors; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.model.function.FunctionCallback; @@ -34,6 +34,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackWrapper2IT { @@ -47,7 +49,7 @@ public class FunctionCallbackWrapper2IT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -62,7 +64,7 @@ void functionCallTest() { .call().content(); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -70,7 +72,7 @@ void functionCallTest() { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -82,7 +84,7 @@ void streamFunctionCallTest() { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -103,4 +105,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java index 5020b4b5633..01a1bd1a93e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -40,7 +41,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackWrapperIT { @@ -55,7 +56,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -64,7 +65,7 @@ void functionCallTest() { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -73,7 +74,7 @@ void functionCallTest() { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -91,7 +92,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -115,4 +116,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java index 60fd35af1c9..f0026ca9f6e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai.tool; import java.util.function.Function; @@ -30,16 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 10; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 10; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfigurationIT.java index 491db782c11..096dea81607 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import java.time.Duration; @@ -66,7 +67,7 @@ public class PostgresMlAutoConfigurationIT { @Test void embedding() { ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withBean(JdbcTemplate.class, () -> jdbcTemplate) + .withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)); contextRunner.run(context -> { PostgresMlEmbeddingModel embeddingModel = context.getBean(PostgresMlEmbeddingModel.class); @@ -85,7 +86,7 @@ void embedding() { @Test void embeddingActivation() { - new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> jdbcTemplate) + new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)) .withPropertyValues("spring.ai.postgresml.embedding.enabled=false") .run(context -> { @@ -93,7 +94,7 @@ void embeddingActivation() { assertThat(context.getBeansOfType(PostgresMlEmbeddingModel.class)).isEmpty(); }); - new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> jdbcTemplate) + new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)) .withPropertyValues("spring.ai.postgresml.embedding.enabled=true") .run(context -> { @@ -101,7 +102,7 @@ void embeddingActivation() { assertThat(context.getBeansOfType(PostgresMlEmbeddingModel.class)).isNotEmpty(); }); - new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> jdbcTemplate) + new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(PostgresMlEmbeddingProperties.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java index 0576e7b5be2..be6a5358595 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import java.util.Map; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java index a854b8ba5f3..002a5f578bb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.image.ImagePrompt; @@ -33,11 +40,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -58,7 +60,7 @@ public class QianFanAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { QianFanChatModel client = context.getBean(QianFanChatModel.class); String response = client.call("Hello"); assertThat(response).isNotEmpty(); @@ -68,7 +70,7 @@ void generate() { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { QianFanChatModel client = context.getBean(QianFanChatModel.class); Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); String response = Objects.requireNonNull(responseFlux.collectList().block()) @@ -82,7 +84,7 @@ void generateStreaming() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { QianFanEmbeddingModel embeddingClient = context.getBean(QianFanEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingClient @@ -99,7 +101,7 @@ void embedding() { @Test void generateImage() { - contextRunner.withPropertyValues("spring.ai.qianfan.image.options.size=1024x1024").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.qianfan.image.options.size=1024x1024").run(context -> { QianFanImageModel imageModel = context.getBean(QianFanImageModel.class); ImageResponse imageResponse = imageModel.call(new ImagePrompt("forest")); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java index c5acafd78f7..1a0ee813f35 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.junit.jupiter.api.Test; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.qianfan.QianFanChatModel; import org.springframework.ai.qianfan.QianFanEmbeddingModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java index f15e605fc50..64759b68091 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java index c663dfb3ab5..c3baab91cb6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java index 0423170af5e..34eb6766380 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.image.Image; -import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.stabilityai.StyleEnum; @@ -38,7 +40,7 @@ public class StabilityAiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ImageModel imageModel = context.getBean(ImageModel.class); StabilityAiImageOptions imageOptions = StabilityAiImageOptions.builder() .withStylePreset(StyleEnum.PHOTOGRAPHIC) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java index 6fbb8509423..c267dd765ac 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java index 6a7cacabb46..9b3a597bed9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.transformers; import java.io.File; @@ -33,15 +34,15 @@ */ public class TransformersEmbeddingModelAutoConfigurationIT { - @TempDir - File tempDir; - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(TransformersEmbeddingModelAutoConfiguration.class)); + @TempDir + File tempDir; + @Test public void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var properties = context.getBean(TransformersEmbeddingModelProperties.class); assertThat(properties.getCache().isEnabled()).isTrue(); assertThat(properties.getCache().getDirectory()).isEqualTo( @@ -54,14 +55,15 @@ public void embedding() { assertThat(embeddings.size()).isEqualTo(2); // batch size assertThat(embeddings.get(0).length).isEqualTo(embeddingModel.dimensions()); // dimensions - // size + // size }); } @Test public void remoteOnnxModel() { // https://huggingface.co/intfloat/e5-small-v2 - contextRunner.withPropertyValues("spring.ai.embedding.transformer.cache.directory=" + tempDir.getAbsolutePath(), + this.contextRunner.withPropertyValues( + "spring.ai.embedding.transformer.cache.directory=" + this.tempDir.getAbsolutePath(), "spring.ai.embedding.transformer.onnx.modelUri=https://huggingface.co/intfloat/e5-small-v2/resolve/main/model.onnx", "spring.ai.embedding.transformer.tokenizer.uri=https://huggingface.co/intfloat/e5-small-v2/raw/main/tokenizer.json") .run(context -> { @@ -72,8 +74,8 @@ public void remoteOnnxModel() { .isEqualTo("https://huggingface.co/intfloat/e5-small-v2/raw/main/tokenizer.json"); assertThat(properties.getCache().isEnabled()).isTrue(); - assertThat(properties.getCache().getDirectory()).isEqualTo(tempDir.getAbsolutePath()); - assertThat(tempDir.listFiles()).hasSize(2); + assertThat(properties.getCache().getDirectory()).isEqualTo(this.tempDir.getAbsolutePath()); + assertThat(this.tempDir.listFiles()).hasSize(2); EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class); assertThat(embeddingModel).isInstanceOf(TransformersEmbeddingModel.class); @@ -84,23 +86,23 @@ public void remoteOnnxModel() { assertThat(embeddings.size()).isEqualTo(2); // batch size assertThat(embeddings.get(0).length).isEqualTo(embeddingModel.dimensions()); // dimensions - // size + // size }); } @Test void embeddingActivation() { - contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=false").run(context -> { assertThat(context.getBeansOfType(TransformersEmbeddingModelProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(TransformersEmbeddingModel.class)).isEmpty(); }); - contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=true").run(context -> { assertThat(context.getBeansOfType(TransformersEmbeddingModelProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(TransformersEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(TransformersEmbeddingModelProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(TransformersEmbeddingModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java index 46ad43c9cfa..1590d750d3b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.azure; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.azure; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,10 +23,12 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -44,7 +43,9 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -55,6 +56,13 @@ @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") public class AzureVectorStoreAutoConfigurationIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(AzureVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.azure.apiKey=" + System.getenv("AZURE_AI_SEARCH_API_KEY"), + "spring.ai.vectorstore.azure.url=" + System.getenv("AZURE_AI_SEARCH_ENDPOINT")) + .withPropertyValues("spring.ai.vectorstore.azure.initialize-schema=true"); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -70,13 +78,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(AzureVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.azure.apiKey=" + System.getenv("AZURE_AI_SEARCH_API_KEY"), - "spring.ai.vectorstore.azure.url=" + System.getenv("AZURE_AI_SEARCH_ENDPOINT")) - .withPropertyValues("spring.ai.vectorstore.azure.initialize-schema=true"); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -87,7 +88,7 @@ public static void beforeAll() { @Test public void addAndSearchTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.azure.initializeSchema=true", "spring.ai.vectorstore.azure.indexName=my_test_index", "spring.ai.vectorstore.azure.defaultTopK=6", "spring.ai.vectorstore.azure.defaultSimilarityThreshold=0.75") @@ -106,7 +107,7 @@ public void addAndSearchTest() { assertThat(vectorStore).isInstanceOf(AzureVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); @@ -120,7 +121,7 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -131,7 +132,7 @@ public void addAndSearchTest() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java index 36bf8511002..47b56c74b5f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.cassandra; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.cassandra; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,12 +39,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Mick Semb Wever @@ -56,11 +57,6 @@ class CassandraVectorStoreAutoConfigurationIT { @Container static CassandraContainer cassandraContainer = new CassandraContainer(DEFAULT_IMAGE_NAME.withTag("5.0")); - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration( AutoConfigurations.of(CassandraVectorStoreAutoConfiguration.class, CassandraAutoConfiguration.class)) @@ -69,9 +65,14 @@ class CassandraVectorStoreAutoConfigurationIT { .withPropertyValues("spring.ai.vectorstore.cassandra.keyspace=test_autoconfigure") .withPropertyValues("spring.ai.vectorstore.cassandra.contentColumnName=doc_chunk"); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + @Test void addAndSearch() { - contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) + this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) .withPropertyValues("spring.ai.vectorstore.cassandra.fixedThreadPoolExecutorSize=8") @@ -79,7 +80,7 @@ void addAndSearch() { .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.CASSANDRA, VectorStoreObservationContext.Operation.ADD); @@ -89,7 +90,7 @@ void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -98,7 +99,7 @@ void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).isEmpty(); @@ -109,6 +110,14 @@ void addAndSearch() { }); } + private String getContactPointHost() { + return cassandraContainer.getContactPoint().getHostString(); + } + + private String getContactPointPort() { + return String.valueOf(cassandraContainer.getContactPoint().getPort()); + } + @Configuration(proxyBeanMethods = false) static class Config { @@ -124,12 +133,4 @@ public EmbeddingModel embeddingModel() { } - private String getContactPointHost() { - return cassandraContainer.getContactPoint().getHostString(); - } - - private String getContactPointPort() { - return String.valueOf(cassandraContainer.getContactPoint().getPort()); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java index cfa1fc37520..3c8201eda79 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.cassandra; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java index 7b061afe525..a22b26c30cf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java @@ -18,16 +18,14 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; - import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java index 7376cf99a6a..1a8e0920123 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,10 +16,16 @@ package org.springframework.ai.autoconfigure.vectorstore.cosmosdb; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -30,10 +36,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -61,8 +64,8 @@ public class CosmosDBVectorStoreAutoConfigurationIT { @BeforeEach public void setup() { - contextRunner.run(context -> { - vectorStore = context.getBean(VectorStore.class); + this.contextRunner.run(context -> { + this.vectorStore = context.getBean(VectorStore.class); }); } @@ -74,20 +77,20 @@ public void testAddSearchAndDeleteDocuments() { Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); // Add the document to the vector store - vectorStore.add(List.of(document1, document2)); + this.vectorStore.add(List.of(document1, document2)); // Perform a similarity search - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results).isNotEmpty(); assertThat(results.get(0).getId()).isEqualTo(document1.getId()); // Remove the documents from the vector store - vectorStore.delete(List.of(document1.getId(), document2.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId())); // Perform a similarity search again - List results2 = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results2).isEmpty(); @@ -126,16 +129,16 @@ void testSimilaritySearchWithFilter() { Document document3 = new Document("3", "A document about the US", metadata3); Document document4 = new Document("4", "A document about the US", metadata4); - vectorStore.add(List.of(document1, document2, document3, document4)); + this.vectorStore.add(List.of(document1, document2, document3, document4)); FilterExpressionBuilder b = new FilterExpressionBuilder(); - List results = vectorStore.similaritySearch(SearchRequest.query("The World") + List results = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression((b.in("country", "UK", "NL")).build())); assertThat(results).hasSize(2); assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); - List results2 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression( b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build())); @@ -143,17 +146,17 @@ void testSimilaritySearchWithFilter() { assertThat(results2).hasSize(1); assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1"); - List results3 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results3 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build())); assertThat(results3).hasSize(1); assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4"); - vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); // Perform a similarity search again - List results4 = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); + List results4 = this.vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); // Verify the search results assertThat(results4).isEmpty(); @@ -174,4 +177,4 @@ public TestObservationRegistry observationRegistry() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java index 60eca7d06b1..08b261c464a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.elasticsearch.ElasticsearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; @@ -42,11 +44,10 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.elasticsearch.ElasticsearchContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -57,11 +58,6 @@ class ElasticsearchVectorStoreAutoConfigurationIT { "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") .withEnv("xpack.security.enabled", "false"); - private List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(ElasticsearchRestClientAutoConfiguration.class, ElasticsearchVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class, @@ -71,6 +67,11 @@ class ElasticsearchVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.elasticsearch.initializeSchema=true", "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY")); + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + // No parametrized test based on similarity function, // by default the bean will be created using cosine. @Test @@ -80,7 +81,7 @@ public void addAndSearchTest() { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.ELASTICSEARCH, VectorStoreObservationContext.Operation.ADD); @@ -98,7 +99,7 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); @@ -109,7 +110,7 @@ public void addAndSearchTest() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.ELASTICSEARCH, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java index b280052ed6d..54dbc442041 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java @@ -16,19 +16,23 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; - import java.util.HashMap; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.dockerjava.api.model.ExposedPort; +import com.github.dockerjava.api.model.PortBinding; +import com.github.dockerjava.api.model.Ports; +import com.vmware.gemfire.testcontainers.GemFireCluster; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -43,14 +47,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.dockerjava.api.model.ExposedPort; -import com.github.dockerjava.api.model.PortBinding; -import com.github.dockerjava.api.model.Ports; -import com.vmware.gemfire.testcontainers.GemFireCluster; - -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Geet Rawat @@ -59,8 +58,6 @@ */ class GemFireVectorStoreAutoConfigurationIT { - private static GemFireCluster gemFireCluster; - private static final String INDEX_NAME = "spring-ai-index"; private static final int BEAM_WIDTH = 50; @@ -79,15 +76,7 @@ class GemFireVectorStoreAutoConfigurationIT { private static final int SERVER_COUNT = 1; - @AfterAll - public static void stopGemFireCluster() { - gemFireCluster.close(); - } - - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + private static GemFireCluster gemFireCluster; private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(GemFireVectorStoreAutoConfiguration.class)) @@ -102,6 +91,16 @@ public static void stopGemFireCluster() { .withPropertyValues("spring.ai.vectorstore.gemfire.port=" + HTTP_SERVICE_PORT) .withPropertyValues("spring.ai.vectorstore.gemfire.initialize-schema=true"); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + @AfterAll + public static void stopGemFireCluster() { + gemFireCluster.close(); + } + @BeforeAll public static void startGemFireCluster() { Ports.Binding hostPort = Ports.Binding.bindPort(HTTP_SERVICE_PORT); @@ -144,11 +143,11 @@ void ensureGemFireVectorStoreCustomConfiguration() { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.GEMFIRE, VectorStoreObservationContext.Operation.ADD); @@ -166,14 +165,14 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.GEMFIRE, VectorStoreObservationContext.Operation.DELETE); @@ -190,18 +189,24 @@ private Map parseIndex(String json) { JsonNode rootNode = new ObjectMapper().readTree(json); Map indexDetails = new HashMap<>(); if (rootNode.isObject()) { - if (rootNode.has("name")) + if (rootNode.has("name")) { indexDetails.put("name", rootNode.get("name").asText()); - if (rootNode.has("beam-width")) + } + if (rootNode.has("beam-width")) { indexDetails.put("beam-width", rootNode.get("beam-width").asInt()); - if (rootNode.has("max-connections")) + } + if (rootNode.has("max-connections")) { indexDetails.put("max-connections", rootNode.get("max-connections").asInt()); - if (rootNode.has("vector-similarity-function")) + } + if (rootNode.has("vector-similarity-function")) { indexDetails.put("vector-similarity-function", rootNode.get("vector-similarity-function").asText()); - if (rootNode.has("buckets")) + } + if (rootNode.has("buckets")) { indexDetails.put("buckets", rootNode.get("buckets").asInt()); - if (rootNode.has("number-of-embeddings")) + } + if (rootNode.has("number-of-embeddings")) { indexDetails.put("number-of-embeddings", rootNode.get("number-of-embeddings").asInt()); + } } return indexDetails; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java index c8ef301c07a..5f69d6ebbef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java @@ -16,12 +16,12 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; import org.springframework.ai.vectorstore.GemFireVectorStore; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Geet Rawat * @author Soby Chacko diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java index 5801f88f14c..c8d620d0823 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; +import java.util.List; + import org.junit.Test; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; @@ -28,8 +32,6 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import java.util.List; - @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "HANA_DATASOURCE_URL", matches = ".+") @EnabledIfEnvironmentVariable(named = "HANA_DATASOURCE_USERNAME", matches = ".+") @@ -37,22 +39,6 @@ @Disabled public class HanaCloudVectorStoreAutoConfigurationIT { - @Test - public void addAndSearch() { - contextRunner.run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); - - List results = vectorStore.similaritySearch("What is Great Depression?"); - Assertions.assertEquals(1, results.size()); - - // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); - List results2 = vectorStore.similaritySearch("Great Depression"); - Assertions.assertEquals(0, results2.size()); - }); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(HanaCloudVectorStoreAutoConfiguration.class, OpenAiAutoConfiguration.class, RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, @@ -70,4 +56,20 @@ public void addAndSearch() { new Document( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression")); + @Test + public void addAndSearch() { + this.contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + vectorStore.add(this.documents); + + List results = vectorStore.similaritySearch("What is Great Depression?"); + Assertions.assertEquals(1, results.size()); + + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); + List results2 = vectorStore.similaritySearch("Great Depression"); + Assertions.assertEquals(0, results2.size()); + }); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java index 1756e8cd7dc..eab9f8f3e11 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; import org.junit.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java index 6697dffbca6..15723b9b0d2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.milvus; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.milvus; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -34,11 +37,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.milvus.MilvusContainer; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -52,18 +53,18 @@ public class MilvusVectorStoreAutoConfigurationIT { @Container private static MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.8"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(MilvusVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class); + List documents = List.of( new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(MilvusVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class); - @Test public void addAndSearch() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.milvus.metricType=COSINE", "spring.ai.vectorstore.milvus.indexType=IVF_FLAT", "spring.ai.vectorstore.milvus.embeddingDimension=384", @@ -75,7 +76,7 @@ public void addAndSearch() { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.MILVUS, VectorStoreObservationContext.Operation.ADD); @@ -85,7 +86,7 @@ public void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -96,7 +97,7 @@ public void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java index 9e10fcc06a7..c47f3ea25a7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.mongo; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.mongo; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; @@ -40,11 +43,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.data.mongodb.core.MongoTemplate; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Eddú Meléndez @@ -59,6 +60,19 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { @Container static MongoDBAtlasLocalContainer mongo = new MongoDBAtlasLocalContainer("mongodb/mongodb-atlas-local:7.0.9"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class) + .withConfiguration(AutoConfigurations.of(MongoAutoConfiguration.class, MongoDataAutoConfiguration.class, + MongoDBAtlasVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class, + SpringAiRetryAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withPropertyValues("spring.data.mongodb.database=springaisample", + "spring.ai.vectorstore.mongodb.initialize-schema=true", + "spring.ai.vectorstore.mongodb.collection-name=test_collection", + // "spring.ai.vectorstore.mongodb.path-name=testembedding", + "spring.ai.vectorstore.mongodb.index-name=text_index", + "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY"), + String.format("spring.data.mongodb.uri=" + mongo.getConnectionString())); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -73,27 +87,14 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { "Testcontainers Testcontainers Testcontainers Testcontainers Testcontainers Testcontainers Testcontainers", Collections.singletonMap("foo", "baz"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class) - .withConfiguration(AutoConfigurations.of(MongoAutoConfiguration.class, MongoDataAutoConfiguration.class, - MongoDBAtlasVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class, - SpringAiRetryAutoConfiguration.class, OpenAiAutoConfiguration.class)) - .withPropertyValues("spring.data.mongodb.database=springaisample", - "spring.ai.vectorstore.mongodb.initialize-schema=true", - "spring.ai.vectorstore.mongodb.collection-name=test_collection", - // "spring.ai.vectorstore.mongodb.path-name=testembedding", - "spring.ai.vectorstore.mongodb.index-name=text_index", - "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY"), - String.format("spring.data.mongodb.uri=" + mongo.getConnectionString())); - @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.MONGODB, VectorStoreObservationContext.Operation.ADD); observationRegistry.clear(); @@ -104,7 +105,7 @@ public void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsEntry("meta2", "meta2"); @@ -114,7 +115,7 @@ public void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).collect(Collectors.toList())); + vectorStore.delete(this.documents.stream().map(Document::getId).collect(Collectors.toList())); assertObservationRegistry(observationRegistry, VectorStoreProvider.MONGODB, VectorStoreObservationContext.Operation.DELETE); @@ -129,29 +130,32 @@ public void addAndSearch() { @Test public void addAndSearchWithFilters() { - contextRunner.withPropertyValues("spring.ai.vectorstore.mongodb.metadata-fields-to-filter=foo").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vectorstore.mongodb.metadata-fields-to-filter=foo") + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + VectorStore vectorStore = context.getBean(VectorStore.class); + vectorStore.add(this.documents); - Thread.sleep(5000); // Await a second for the document to be indexed + Thread.sleep(5000); // Await a second for the document to be indexed - List results = vectorStore.similaritySearch(SearchRequest.query("Testcontainers").withTopK(2)); - assertThat(results).hasSize(2); - results.forEach(doc -> assertThat(doc.getContent().contains("Testcontainers")).isTrue()); + List results = vectorStore + .similaritySearch(SearchRequest.query("Testcontainers").withTopK(2)); + assertThat(results).hasSize(2); + results.forEach(doc -> assertThat(doc.getContent().contains("Testcontainers")).isTrue()); - FilterExpressionBuilder b = new FilterExpressionBuilder(); - results = vectorStore.similaritySearch( - SearchRequest.query("Testcontainers").withTopK(2).withFilterExpression(b.eq("foo", "bar").build())); + FilterExpressionBuilder b = new FilterExpressionBuilder(); + results = vectorStore.similaritySearch(SearchRequest.query("Testcontainers") + .withTopK(2) + .withFilterExpression(b.eq("foo", "bar").build())); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(3).getId()); - assertThat(resultDoc.getContent().contains("Testcontainers")).isTrue(); - assertThat(resultDoc.getMetadata()).containsEntry("foo", "bar"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(3).getId()); + assertThat(resultDoc.getContent().contains("Testcontainers")).isTrue(); + assertThat(resultDoc.getMetadata()).containsEntry("foo", "bar"); - context.getBean(MongoTemplate.class).dropCollection("test_collection"); - }); + context.getBean(MongoTemplate.class).dropCollection("test_collection"); + }); } @Configuration(proxyBeanMethods = false) @@ -164,4 +168,4 @@ public TestObservationRegistry observationRegistry() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java index 3f56ca28498..09b12c83a23 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.neo4j; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.neo4j; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,12 +39,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.containers.Neo4jContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Jingzhou Ou @@ -57,11 +58,6 @@ public class Neo4jVectorStoreAutoConfigurationIT { static Neo4jContainer neo4jContainer = new Neo4jContainer<>(DockerImageName.parse("neo4j:5.18")) .withRandomPassword(); - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(Neo4jAutoConfiguration.class, Neo4jVectorStoreAutoConfiguration.class)) .withUserConfiguration(Config.class) @@ -69,9 +65,14 @@ public class Neo4jVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.neo4j.initialize-schema=true", "spring.neo4j.authentication.username=" + "neo4j", "spring.neo4j.authentication.password=" + neo4jContainer.getAdminPassword()); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + @Test void addAndSearch() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.neo4j.label=my_test_label", "spring.ai.vectorstore.neo4j.embeddingDimension=384", "spring.ai.vectorstore.neo4j.indexName=customIndexName") @@ -84,7 +85,7 @@ void addAndSearch() { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.NEO4J, VectorStoreObservationContext.Operation.ADD); @@ -94,7 +95,7 @@ void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -103,7 +104,7 @@ void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.NEO4J, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java index 01dee102613..b0e4cdafaf1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java @@ -1,27 +1,28 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.vectorstore.observation; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; - /** * @author Christian Tzolov * @since 1.0.0 diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java index 0fdd547abf3..29b8a387895 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.observation; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; import io.opentelemetry.api.OpenTelemetry; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationFilter; import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit tests for {@link VectorStoreObservationAutoConfiguration}. * @@ -38,21 +40,21 @@ class VectorStoreObservationAutoConfigurationTests { @Test void queryResponseFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationFilter.class); }); } @Test void queryResponseHandlerDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationHandler.class); }); } @Test void queryResponseHandlerEnabled() { - contextRunner + this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.vectorstore.observations.include-query-response=true") .run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java index b0f6e0ab1ab..7b23132d7b6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.opensearch; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Map; + import com.jayway.jsonpath.JsonPath; import net.minidev.json.JSONArray; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.localstack.LocalStackContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -31,16 +43,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.localstack.LocalStackContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; @@ -56,11 +58,6 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { private static final String DOCUMENT_INDEX = "auto-spring-ai-document-index"; - private List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class)) @@ -86,6 +83,11 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { } """); + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @BeforeAll static void beforeAll() throws IOException, InterruptedException { String[] createDomainCmd = { "awslocal", "opensearch", "create-domain", "--domain-name", @@ -109,7 +111,7 @@ public void addAndSearchTest() { this.contextRunner.run(context -> { OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -121,14 +123,14 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java index e7c58462184..5445f6de7d1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.opensearch; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.opensearch; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -41,13 +45,10 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import io.micrometer.observation.tck.TestObservationRegistry; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; @Testcontainers class OpenSearchVectorStoreAutoConfigurationIT { @@ -58,11 +59,6 @@ class OpenSearchVectorStoreAutoConfigurationIT { private static final String DOCUMENT_INDEX = "auto-spring-ai-document-index"; - private List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class)) @@ -82,6 +78,11 @@ class OpenSearchVectorStoreAutoConfigurationIT { } """); + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @Test public void addAndSearchTest() { @@ -89,7 +90,7 @@ public void addAndSearchTest() { OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.OPENSEARCH, VectorStoreObservationContext.Operation.ADD); @@ -111,14 +112,14 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.OPENSEARCH, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java index 6f9952ccc1b..078c381d236 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.oracle; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.oracle; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -38,12 +42,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.oracle.OracleContainer; -import org.testcontainers.utility.MountableFile; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -58,11 +59,6 @@ public class OracleVectorStoreAutoConfigurationIT { .withCopyFileToContainer(MountableFile.forClasspathResource("/oracle/initialize.sql"), "/container-entrypoint-initdb.d/initialize.sql"); - List documents = List.of( - new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(getText("classpath:/test/data/time.shelter.txt")), - new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(OracleVectorStoreAutoConfiguration.class, JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) @@ -76,14 +72,29 @@ public class OracleVectorStoreAutoConfigurationIT { String.format("spring.datasource.password=%s", oracle23aiContainer.getPassword()), "spring.datasource.type=oracle.jdbc.pool.OracleDataSource"); + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.ORACLE, VectorStoreObservationContext.Operation.ADD); @@ -94,7 +105,7 @@ public void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); assertObservationRegistry(observationRegistry, VectorStoreProvider.ORACLE, @@ -102,7 +113,7 @@ public void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.ORACLE, VectorStoreObservationContext.Operation.DELETE); @@ -113,16 +124,6 @@ public void addAndSearch() { }); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) static class Config { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java index d7e00a9eff9..b57df38922d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.oracle; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.oracle; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.OracleVectorStore; import org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType; import org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreIndexType; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java index 6777b04e15e..25a969fceda 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pgvector; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.pgvector; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -42,11 +45,9 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -61,6 +62,18 @@ public class PgVectorStoreAutoConfigurationIT { @SuppressWarnings("resource") static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg16"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PgVectorStoreAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", + "spring.ai.vectorstore.pgvector.initialize-schema=true", + // JdbcTemplate configuration + String.format("spring.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), postgresContainer.getDatabaseName()), + "spring.datasource.username=" + postgresContainer.getUsername(), + "spring.datasource.password=" + postgresContainer.getPassword()); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -76,22 +89,17 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(PgVectorStoreAutoConfiguration.class, - JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", - "spring.ai.vectorstore.pgvector.initialize-schema=true", - // JdbcTemplate configuration - String.format("spring.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), postgresContainer.getDatabaseName()), - "spring.datasource.username=" + postgresContainer.getUsername(), - "spring.datasource.password=" + postgresContainer.getPassword()); + private static boolean isFullyQualifiedTableExists(ApplicationContext context, String schemaName, + String tableName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + String sql = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)"; + return jdbcTemplate.queryForObject(sql, Boolean.class, schemaName, tableName); + } @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { PgVectorStore vectorStore = context.getBean(PgVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); @@ -100,7 +108,7 @@ public void addAndSearch() { PgVectorStore.DEFAULT_TABLE_NAME)) .isTrue(); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.PG_VECTOR, VectorStoreObservationContext.Operation.ADD); @@ -111,7 +119,7 @@ public void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); assertObservationRegistry(observationRegistry, VectorStoreProvider.PG_VECTOR, @@ -119,7 +127,7 @@ public void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.PG_VECTOR, VectorStoreObservationContext.Operation.DELETE); @@ -136,7 +144,7 @@ public void customSchemaNames(String schemaTableName) { String schemaName = schemaTableName.split(":")[0]; String tableName = schemaTableName.split(":")[1]; - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.pgvector.schema-name=" + schemaName, "spring.ai.vectorstore.pgvector.table-name=" + tableName) .run(context -> { @@ -150,7 +158,7 @@ public void disableSchemaInitialization(String schemaTableName) { String schemaName = schemaTableName.split(":")[0]; String tableName = schemaTableName.split(":")[1]; - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.pgvector.schema-name=" + schemaName, "spring.ai.vectorstore.pgvector.table-name=" + tableName, "spring.ai.vectorstore.pgvector.initialize-schema=false") @@ -174,11 +182,4 @@ public EmbeddingModel embeddingModel() { } - private static boolean isFullyQualifiedTableExists(ApplicationContext context, String schemaName, - String tableName) { - JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); - String sql = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)"; - return jdbcTemplate.queryForObject(sql, Boolean.class, schemaName, tableName); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java index a4e4ddef365..f52f4ce782c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pgvector; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.pgvector; import org.junit.jupiter.api.Test; @@ -23,6 +22,8 @@ import org.springframework.ai.vectorstore.PgVectorStore.PgDistanceType; import org.springframework.ai.vectorstore.PgVectorStore.PgIndexType; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java index 461e02abb6f..e5cc9f97d80 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pinecone; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.pinecone; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,10 +23,12 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -43,7 +42,9 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -53,6 +54,16 @@ @EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+") public class PineconeVectorStoreAutoConfigurationIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PineconeVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.pinecone.apiKey=" + System.getenv("PINECONE_API_KEY"), + "spring.ai.vectorstore.pinecone.environment=gcp-starter", + "spring.ai.vectorstore.pinecone.projectId=814621f", + "spring.ai.vectorstore.pinecone.indexName=spring-ai-test-index", + "spring.ai.vectorstore.pinecone.contentFieldName=customContentField", + "spring.ai.vectorstore.pinecone.distanceMetadataFieldName=customDistanceField"); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -68,16 +79,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(PineconeVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.pinecone.apiKey=" + System.getenv("PINECONE_API_KEY"), - "spring.ai.vectorstore.pinecone.environment=gcp-starter", - "spring.ai.vectorstore.pinecone.projectId=814621f", - "spring.ai.vectorstore.pinecone.indexName=spring-ai-test-index", - "spring.ai.vectorstore.pinecone.contentFieldName=customContentField", - "spring.ai.vectorstore.pinecone.distanceMetadataFieldName=customDistanceField"); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -88,12 +89,12 @@ public static void beforeAll() { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { PineconeVectorStore vectorStore = context.getBean(PineconeVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.ADD); @@ -107,7 +108,7 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -118,7 +119,7 @@ public void addAndSearchTest() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java index ce006a43887..ce450bac029 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pinecone; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.pinecone; import java.time.Duration; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.PineconeVectorStore; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java index 99525be0eeb..6a418ab180e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.qdrant; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.qdrant; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -36,11 +39,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -55,11 +56,6 @@ public class QdrantVectorStoreAutoConfigurationIT { @Container static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.9.2"); - List documents = List.of( - new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(getText("classpath:/test/data/time.shelter.txt")), - new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(QdrantVectorStoreAutoConfiguration.class)) .withUserConfiguration(Config.class) @@ -67,14 +63,29 @@ public class QdrantVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.qdrant.initialize-schema=true", "spring.ai.vectorstore.qdrant.host=" + qdrantContainer.getHost()); + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.QDRANT, VectorStoreObservationContext.Operation.ADD); @@ -85,7 +96,7 @@ public void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); assertObservationRegistry(observationRegistry, VectorStoreProvider.QDRANT, @@ -93,7 +104,7 @@ public void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); assertThat(results).hasSize(0); @@ -103,16 +114,6 @@ public void addAndSearch() { }); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) static class Config { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java index 34e194c6753..2358a821eb3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import java.io.IOException; @@ -67,6 +68,15 @@ public class QdrantVectorStoreCloudAutoConfigurationIT { // NOTE: The GRPC port (usually 6334) is different from the HTTP port (usually 6333)! private static final int CLOUD_GRPC_PORT = 6334; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(QdrantVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.qdrant.port=" + CLOUD_GRPC_PORT, + "spring.ai.vectorstore.qdrant.host=" + CLOUD_HOST, + "spring.ai.vectorstore.qdrant.api-key=" + CLOUD_API_KEY, + "spring.ai.vectorstore.qdrant.collection-name=" + COLLECTION_NAME, + "spring.ai.vectorstore.qdrant.initializeSchema=true", "spring.ai.vectorstore.qdrant.use-tls=true"); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -92,48 +102,39 @@ static void setup() throws InterruptedException, ExecutionException { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(QdrantVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.qdrant.port=" + CLOUD_GRPC_PORT, - "spring.ai.vectorstore.qdrant.host=" + CLOUD_HOST, - "spring.ai.vectorstore.qdrant.api-key=" + CLOUD_API_KEY, - "spring.ai.vectorstore.qdrant.collection-name=" + COLLECTION_NAME, - "spring.ai.vectorstore.qdrant.initializeSchema=true", "spring.ai.vectorstore.qdrant.use-tls=true"); + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); assertThat(results).hasSize(0); }); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) static class Config { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java index 878f298ab3b..31c3e528262 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java index 634447464db..41295f128b3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.redis; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.redis; import java.util.List; import java.util.Map; +import com.redis.testcontainers.RedisStackContainer; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,12 +38,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import com.redis.testcontainers.RedisStackContainer; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Julien Ruaux @@ -56,11 +56,6 @@ class RedisVectorStoreAutoConfigurationIT { static RedisStackContainer redisContainer = new RedisStackContainer( RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class, RedisVectorStoreAutoConfiguration.class)) .withUserConfiguration(Config.class) @@ -69,13 +64,18 @@ class RedisVectorStoreAutoConfigurationIT { .withPropertyValues("spring.ai.vectorstore.redis.index=myIdx") .withPropertyValues("spring.ai.vectorstore.redis.prefix=doc:"); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.REDIS, VectorStoreObservationContext.Operation.ADD); @@ -85,7 +85,7 @@ void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -94,7 +94,7 @@ void addAndSearch() { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.REDIS, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java index cb36910319c..0b38b40c766 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.redis; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.redis; import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Julien Ruaux * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java index aea63b614a1..54aaaf8a132 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.typesense; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.typesense; import java.time.Duration; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,11 +38,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Pablo Sanchidrian Herrera @@ -57,18 +58,18 @@ public class TypesenseVectorStoreAutoConfigurationIT { .withCommand("--data-dir", "/tmp", "--api-key=xyz", "--enable-cors") .withStartupTimeout(Duration.ofSeconds(100)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(TypesenseVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class); + List documents = List.of( new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(TypesenseVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class); - @Test public void addAndSearch() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.typesense.embeddingDimension=384", "spring.ai.vectorstore.typesense.collectionName=myTestCollection", "spring.ai.vectorstore.typesense.initialize-schema=true", @@ -80,7 +81,7 @@ public void addAndSearch() { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.TYPESENSE, VectorStoreObservationContext.Operation.ADD); @@ -90,7 +91,7 @@ public void addAndSearch() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -100,7 +101,7 @@ public void addAndSearch() { VectorStoreObservationContext.Operation.QUERY); observationRegistry.clear(); - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.TYPESENSE, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java index ba81c747924..02d3b2a3efa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,9 +38,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.weaviate.WeaviateContainer; - -import io.micrometer.observation.tck.TestObservationRegistry; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; @@ -68,7 +68,7 @@ public class WeaviateVectorStoreAutoConfigurationIT { @Test public void addAndSearchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { WeaviateVectorStoreProperties properties = context.getBean(WeaviateVectorStoreProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java index 3cb3481a227..d5397b56adb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.embedding; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.embedding; import java.io.File; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.io.TempDir; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingRequest; import org.springframework.ai.embedding.EmbeddingOptions; @@ -33,6 +33,8 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ @@ -40,17 +42,17 @@ @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") public class VertexAiTextEmbeddingModelAutoConfigurationIT { - @TempDir - File tempDir; - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.vertex.ai.embedding.project-id=" + System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"), "spring.ai.vertex.ai.embedding.location=" + System.getenv("VERTEX_AI_GEMINI_LOCATION")) .withConfiguration(AutoConfigurations.of(VertexAiEmbeddingAutoConfiguration.class)); + @TempDir + File tempDir; + @Test public void textEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var conntectionProperties = context.getBean(VertexAiEmbeddingConnectionProperties.class); var textEmbeddingProperties = context.getBean(VertexAiTextEmbeddingProperties.class); @@ -69,17 +71,17 @@ public void textEmbedding() { @Test void textEmbeddingActivation() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isEmpty(); }); - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isNotEmpty(); }); @@ -88,7 +90,7 @@ void textEmbeddingActivation() { @Test public void multimodalEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var conntectionProperties = context.getBean(VertexAiEmbeddingConnectionProperties.class); var multimodalEmbeddingProperties = context.getBean(VertexAiMultimodalEmbeddingProperties.class); @@ -122,17 +124,17 @@ public void multimodalEmbedding() { @Test void multimodalEmbeddingActivation() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isEmpty(); }); - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java index 7ab1ef34fea..a9d9716a515 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import java.util.stream.Collectors; @@ -21,12 +22,12 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -45,7 +46,7 @@ public class VertexAiGeminiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -55,7 +56,7 @@ void generate() { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java index 66a8c0113a8..4d17b12cf32 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.gemini.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; import java.util.function.Function; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -37,6 +37,8 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") class FunctionCallWithFunctionBeanIT { @@ -53,7 +55,7 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" // + VertexAiGeminiChatModel.ChatModel.GEMINI_PRO_1_5_PRO.getValue()) + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -69,21 +71,21 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel .call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15"); @@ -93,7 +95,7 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" // + VertexAiGeminiChatModel.ChatModel.GEMINI_PRO_1_5_PRO.getValue()) + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -109,14 +111,14 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -142,4 +144,4 @@ public Function weather } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java index 08a095119ae..34688fcef77 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.gemini.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -37,6 +37,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") public class FunctionCallWithFunctionWrapperIT { @@ -51,7 +53,7 @@ public class FunctionCallWithFunctionWrapperIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -66,7 +68,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -87,4 +89,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java index e72ac44806d..2cde310ab6a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.gemini.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -34,6 +34,8 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") public class FunctionCallWithPromptFunctionIT { @@ -47,7 +49,7 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -75,7 +77,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -83,11 +85,11 @@ void functionCallTest() { response = chatModel .call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java index ed34d7b0cf9..aa78f759467 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.function.Function; @@ -31,14 +32,21 @@ @JsonClassDescription("Get the weather in location") public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -66,28 +74,23 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - - @Override - public Response apply(Request request) { - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPaLm2AutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPaLm2AutoConfigurationIT.java index 5634da8ad07..94dd10d26dd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPaLm2AutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPaLm2AutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.palm2; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.palm2; import java.util.List; @@ -23,12 +22,15 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.vertexai.palm2.VertexAiPaLm2ChatModel; import org.springframework.ai.vertexai.palm2.VertexAiPaLm2EmbeddingModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + // NOTE: works only with US location. Use VPN if you are outside US. @EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*") public class VertexAiPaLm2AutoConfigurationIT { @@ -44,7 +46,7 @@ public class VertexAiPaLm2AutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiPaLm2ChatModel chatModel = context.getBean(VertexAiPaLm2ChatModel.class); String response = chatModel.call("Hello"); @@ -56,7 +58,7 @@ void generate() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiPaLm2EmbeddingModel embeddingModel = context.getBean(VertexAiPaLm2EmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -75,19 +77,19 @@ void embedding() { public void embeddingActivation() { // Disable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiPalm2EmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2EmbeddingModel.class)).isEmpty(); }); // The embedding auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiPalm2EmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2EmbeddingModel.class)).isNotEmpty(); }); // Explicitly enable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiPalm2EmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2EmbeddingModel.class)).isNotEmpty(); }); @@ -97,19 +99,19 @@ public void embeddingActivation() { public void chatActivation() { // Disable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiPlam2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2ChatModel.class)).isEmpty(); }); // The chat auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiPlam2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2ChatModel.class)).isNotEmpty(); }); // Explicitly enable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiPlam2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2ChatModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java index 1637b204ebf..049fd61c71b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; import org.junit.Test; + import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -51,4 +53,4 @@ public void propertiesTest() { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java index b9ec4e45fe6..f15f82f82e4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; +import java.util.List; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -32,10 +38,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +56,7 @@ public class ZhiPuAiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -64,7 +66,7 @@ void generate() { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -78,7 +80,7 @@ void generateStreaming() { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ZhiPuAiEmbeddingModel embeddingModel = context.getBean(ZhiPuAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -95,7 +97,7 @@ void embedding() { @Test void generateImage() { - contextRunner.withPropertyValues("spring.ai.zhipuai.image.options.size=1024x1024").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.image.options.size=1024x1024").run(context -> { ZhiPuAiImageModel ImageModel = context.getBean(ZhiPuAiImageModel.class); ImageResponse imageResponse = ImageModel.call(new ImagePrompt("forest")); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java index 2aaf6294622..2cc2c9b21e5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.zhipuai.ZhiPuAiChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java index 8dc63f205c4..ca91b63c3d0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -32,10 +38,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +56,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -71,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -80,7 +82,7 @@ void functionCallTest() { @Test void streamingFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -105,7 +107,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -113,4 +115,4 @@ void streamingFunctionCallTest() { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 7b440c38158..5b2657c6918 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -36,11 +43,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +62,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -71,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -79,7 +81,7 @@ void functionCallTest() { response = chatModel.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -88,7 +90,7 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -102,13 +104,13 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); }); } @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -127,7 +129,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -145,7 +147,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -172,4 +174,4 @@ public Function weather } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java index 2596f3b8431..9016104f214 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -35,10 +41,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -58,7 +60,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -68,7 +70,7 @@ void functionCallTest() { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -77,7 +79,7 @@ void functionCallTest() { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -95,7 +97,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -119,4 +121,4 @@ public FunctionCallback weatherFunctionInfo() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java index 61f6d6c2db7..75d562648f6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * Mock 3rd party weather service. * @@ -30,16 +31,21 @@ */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ private Unit(String text) { } + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + /** * Weather Function response. */ public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, Unit unit) { - } - @Override - public Response apply(Request request) { - - double temperature = 0; - if (request.location().contains("Paris")) { - temperature = 15; - } - else if (request.location().contains("Tokyo")) { - temperature = 10; - } - else if (request.location().contains("San Francisco")) { - temperature = 30; - } - - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql b/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql index ac38a19652f..0b42b6ff7ea 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql +++ b/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + -- Exit on any errors WHENEVER SQLERROR EXIT SQL.SQLCODE diff --git a/spring-ai-spring-boot-docker-compose/pom.xml b/spring-ai-spring-boot-docker-compose/pom.xml index fb13e0efec7..0e714abf082 100644 --- a/spring-ai-spring-boot-docker-compose/pom.xml +++ b/spring-ai-spring-boot-docker-compose/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java index b861bea6ad2..7d3f650aaa2 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java index ddfba20fb40..371ed7c9a3e 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java index 8de8fb4ec68..7c04ff256bf 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.mongo; import com.mongodb.ConnectionString; + import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails; import org.springframework.boot.docker.compose.core.RunningService; import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionDetailsFactory; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java index db84ed58179..6608cfa0a10 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.ollama; import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java index 71811252482..0fbfe2088dc 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.opensearch; +import java.util.List; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails; import org.springframework.boot.docker.compose.core.RunningService; import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionDetailsFactory; import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionSource; -import java.util.List; - /** * @author Eddú Meléndez */ diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java index 56adc0afe9d..034ba169f9e 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.opensearch; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java index c73781dac5a..2de9b0a4be9 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.qdrant; import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java index 8752ad38f1f..8005a0ed284 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.docker.compose.service.connection.qdrant; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java index 9de92e6131b..ad7691716a7 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.typesense; import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java index 139815a5de6..b8b70e44901 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.docker.compose.service.connection.typesense; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java index 6c88ff2c1dc..2f8216b7db2 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.weaviate; import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories b/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories index cf904157592..fcc2bfdc36b 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory=\ org.springframework.ai.docker.compose.service.connection.chroma.ChromaDockerComposeConnectionDetailsFactory,\ org.springframework.ai.docker.compose.service.connection.mongo.MongoDbAtlasLocalDockerComposeConnectionDetailsFactory,\ diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java index 5c015d86291..66e670c467e 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java index df3e514d0e4..37d416d13ef 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.docker.compose.service.connection.chroma; -import org.junit.jupiter.api.Test; +package org.springframework.ai.docker.compose.service.connection.chroma; import java.util.Collections; import java.util.Map; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; class ChromaEnvironmentTests { diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java index 493f8ca54e5..8795dc4f69c 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java index b312d71b4d9..c88bf10cd0f 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java @@ -1,9 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.docker.compose.service.connection.mongo; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java index 7f88d6a35f6..9b72828195d 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.ollama; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java index e50a8655db9..a162d05d23b 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.opensearch; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java index 7e7ba6c42a9..c1457232776 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.docker.compose.service.connection.opensearch; -import org.junit.jupiter.api.Test; +package org.springframework.ai.docker.compose.service.connection.opensearch; import java.util.Collections; import java.util.Map; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; class OpenSearchEnvironmentTests { diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java index bc907baaae3..7dd990bc06b 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.qdrant; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java index a0c3925fe6e..b766c31cbec 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.typesense; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java index 9b65f53c4ac..9ed0920dbbb 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.docker.compose.service.connection.typesense; -import org.junit.jupiter.api.Test; +package org.springframework.ai.docker.compose.service.connection.typesense; import java.util.Collections; import java.util.Map; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; class TypesenseEnvironmentTests { diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java index d046d8fcb71..58f6a5a4944 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.weaviate; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java index b1279bd7673..8c6289c87d4 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,14 +26,14 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.io.TempDir; -import org.springframework.boot.autoconfigure.ImportAutoConfiguration; -import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration; -import org.springframework.boot.testsupport.DisabledIfProcessUnavailable; import org.testcontainers.utility.DockerImageName; import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplicationShutdownHandlers; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; +import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration; +import org.springframework.boot.testsupport.DisabledIfProcessUnavailable; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; @@ -60,17 +60,17 @@ public abstract class AbstractDockerComposeIntegrationTests { private final DockerImageName dockerImageName; + protected AbstractDockerComposeIntegrationTests(String composeResource, DockerImageName dockerImageName) { + this.composeResource = new ClassPathResource(composeResource, getClass()); + this.dockerImageName = dockerImageName; + } + @AfterAll static void shutDown() { SpringApplicationShutdownHandlers shutdownHandlers = SpringApplication.getShutdownHandlers(); ((Runnable) shutdownHandlers).run(); } - protected AbstractDockerComposeIntegrationTests(String composeResource, DockerImageName dockerImageName) { - this.composeResource = new ClassPathResource(composeResource, getClass()); - this.dockerImageName = dockerImageName; - } - protected final T run(Class type) { SpringApplication application = new SpringApplication(Config.class); Map properties = new LinkedHashMap<>(); diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java index aded7c820e5..857f2d6d9aa 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,6 @@ package org.springframework.boot.testsupport; -import org.junit.jupiter.api.extension.ExtendWith; - import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Repeatable; @@ -25,6 +23,8 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.junit.jupiter.api.extension.ExtendWith; + /** * Disables test execution if a process is unavailable. * diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java index 98690a1cc6f..62dd2f20ee6 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,22 @@ package org.springframework.boot.testsupport; +import java.lang.reflect.AnnotatedElement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; + import org.junit.jupiter.api.extension.ConditionEvaluationResult; import org.junit.jupiter.api.extension.ExecutionCondition; import org.junit.jupiter.api.extension.ExtensionContext; + import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotations; import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.lang.reflect.AnnotatedElement; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.stream.Stream; - /** * An {@link ExecutionCondition} that disables execution if specified processes cannot * start. diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java index b62bf3177e2..bfc2e88f297 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,14 @@ package org.springframework.boot.testsupport; -import org.junit.jupiter.api.extension.ExtendWith; - import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.junit.jupiter.api.extension.ExtendWith; + /** * Repeatable container for {@link DisabledIfProcessUnavailable}. * diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml index 91584cb4540..31e711ba2b7 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml index dc5d637b918..186f0bad8ed 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml index ac8c2ffb1c5..e2c3ec1cc2f 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml index 61a3a6d9d57..7ccc8ef96bf 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml index d9292532ea7..f5face502c4 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml index 781dc92a6b0..fc95593d500 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml index 00b58ba7586..968e4fa5356 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml index 6dd309936f9..f9e588f8bc7 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml index 363ca6d1151..efdba43b9b9 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml index de48a933f30..b393648d61b 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml index bba6eddcef6..0309e3970bf 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml index 18d9e2c0970..0832116aa3c 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml index 825e9f03040..49c66a4c2c8 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml index 3004b42a9db..124570853ac 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml index 05a61791490..4c00f83262b 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml index 52cb6d05805..7f868e68a0c 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml index 2c2b19a84f7..70c34367bf0 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml index cc294111481..f0203e6973e 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml index 2fd347de32d..f8a83528786 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml index fc35584038d..7b16c0c672d 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml index 95b60e64261..a5bce988897 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml index c97eb81ad68..07533074db8 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml index 210cb8301e0..72d2cde2f1f 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml index a194100e6fe..141e9f2a1b8 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml index aefe6a0a62c..bb1977f6d82 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml index 0378542fa4d..ffbc4aa5543 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml index acbfe9a28d8..fddb080fe56 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml index e8a3124671c..0da3b1776ae 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml index 09fe4abb2e8..637b3e44296 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml index b2a74b1663d..64764552e2b 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml index 5cab1733dce..9b0e6164338 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml index de5170aa8cb..2b3e525187e 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml index f59c3533000..36fcee4d1e9 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml index 3ef367efa41..778b9cccde3 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml index eec7a61c030..97d0f1077c7 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml index 44dc3af3f26..36e871422a4 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml index ea237a2e159..e46246176c4 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml index 060ac290810..9eb4283c110 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-testcontainers/pom.xml b/spring-ai-spring-boot-testcontainers/pom.xml index a56a7af7aac..d1d5d14853d 100644 --- a/spring-ai-spring-boot-testcontainers/pom.xml +++ b/spring-ai-spring-boot-testcontainers/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java index 5efb6b1afd8..909d43f7a30 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.Map; + +import org.testcontainers.chromadb.ChromaDBContainer; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.chromadb.ChromaDBContainer; - -import java.util.Map; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java index a44643c382e..137f8529676 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.milvus; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.milvus.MilvusContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java index 8bc4b2021c8..bf425a33c6a 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.mongo; import com.mongodb.ConnectionString; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; /** * A {@link ContainerConnectionDetailsFactory} implementation that provides diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java index 46174bc36fb..b800be8b7fb 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.ollama; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.ollama.OllamaContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java index 22898785fcd..a154ddaf0d6 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.opensearch; +import java.util.List; + import org.opensearch.testcontainers.OpensearchContainer; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import java.util.List; - /** * @author Eddú Meléndez */ diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java index 619793d1b2c..e6a6bd5368c 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.qdrant.QdrantContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java index 5779b3e53ff..68769925a9d 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.typesense; +import org.testcontainers.containers.Container; + import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.containers.Container; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java index 601fb6244da..d953dfa9c7d 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.weaviate; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.weaviate.WeaviateContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories b/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories index bf8370e68ae..a4d88b8ef6b 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory=\ org.springframework.ai.testcontainers.service.connection.chroma.ChromaContainerConnectionDetailsFactory,\ org.springframework.ai.testcontainers.service.connection.milvus.MilvusContainerConnectionDetailsFactory,\ diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java index 2bf14e3de3c..2c387f2c1cb 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java @@ -15,8 +15,15 @@ */ package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,12 +37,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index 2eef549497e..efd038e5d38 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java index 1c7daf6423f..86c9e2718a5 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java @@ -15,8 +15,15 @@ */ package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,12 +37,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java index 34460d6fb30..fbe935032c3 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java @@ -15,8 +15,15 @@ */ package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,12 +37,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java index c4aed4e3497..2f3f51910a8 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.milvus; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; @@ -30,12 +38,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.milvus.MilvusContainer; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -61,22 +63,22 @@ class MilvusContainerConnectionDetailsFactoryTest { @Test public void addAndSearch() { - vectorStore.add(documents); + this.vectorStore.add(this.documents); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()) .contains("Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); } @@ -91,4 +93,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java index e125d08ba7a..168e854e799 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.milvus; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java index 34b748de47e..1f05b5d0aca 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.mongo; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -32,13 +41,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; - -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -68,10 +70,10 @@ public void addAndSearch() throws InterruptedException { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - vectorStore.add(documents); + this.vectorStore.add(documents); Thread.sleep(5000); // Await a second for the document to be indexed - List results = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -81,9 +83,9 @@ public void addAndSearch() throws InterruptedException { assertThat(resultDoc.getMetadata()).containsEntry("meta2", "meta2"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).collect(Collectors.toList())); + this.vectorStore.delete(documents.stream().map(Document::getId).collect(Collectors.toList())); - List results2 = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results2).isEmpty(); } @@ -99,4 +101,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java index 3ac2aa28450..af0cb68f5ad 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.mongo; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java index 6e139b74796..8cb9f4a124a 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.ollama; +import java.io.IOException; +import java.util.List; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.ollama.OllamaEmbeddingModel; @@ -30,12 +38,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.ollama.OllamaContainer; - -import java.io.IOException; -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -50,10 +52,10 @@ + OllamaContainerConnectionDetailsFactoryTest.MODEL_NAME) class OllamaContainerConnectionDetailsFactoryTest { - private static final Logger logger = LoggerFactory.getLogger(OllamaContainerConnectionDetailsFactoryTest.class); - static final String MODEL_NAME = "nomic-embed-text"; + private static final Logger logger = LoggerFactory.getLogger(OllamaContainerConnectionDetailsFactoryTest.class); + @Container @ServiceConnection static OllamaContainer ollama = new OllamaContainer(OllamaImage.DEFAULT_IMAGE); @@ -82,4 +84,4 @@ static class Config { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java index 59cf4940360..c1bce0c70f0 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.ollama; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java index 43760317ee1..7e4393beb85 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.testcontainers.service.connection.opensearch; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; +package org.springframework.ai.testcontainers.service.connection.opensearch; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,6 +24,9 @@ import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -39,8 +40,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; @SpringBootTest(properties = { "spring.ai.vectorstore.opensearch.index-name=" + OpenSearchContainerConnectionDetailsFactoryTest.DOCUMENT_INDEX, @@ -50,14 +52,14 @@ @Testcontainers class OpenSearchContainerConnectionDetailsFactoryTest { - @Container - @ServiceConnection - private static final OpensearchContainer opensearch = new OpensearchContainer<>(OpenSearchImage.DEFAULT_IMAGE); - static final String DOCUMENT_INDEX = "auto-spring-ai-document-index"; static final String MAPPING_JSON = "{\"properties\":{\"embedding\":{\"type\":\"knn_vector\",\"dimension\":384}}}"; + @Container + @ServiceConnection + private static final OpensearchContainer opensearch = new OpensearchContainer<>(OpenSearchImage.DEFAULT_IMAGE); + private final List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -69,29 +71,29 @@ class OpenSearchContainerConnectionDetailsFactoryTest { @Test public void addAndSearchTest() { - vectorStore.add(documents); + this.vectorStore.add(this.documents); Awaitility.await() - .until(() -> vectorStore + .until(() -> this.vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), hasSize(1)); - List results = vectorStore + List results = this.vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + this.vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore + .until(() -> this.vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), hasSize(0)); } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java index 8c636cbbab0..26a615e6a32 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.opensearch; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java index bdc5a7b10f0..c25773c79d6 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,14 +40,6 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -59,34 +61,34 @@ public class QdrantContainerConnectionDetailsFactoryTest { @Autowired private VectorStore vectorStore; + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearch() { - vectorStore.add(documents); + this.vectorStore.add(this.documents); - List results = vectorStore + List results = this.vectorStore .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + results = this.vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); assertThat(results).hasSize(0); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) @ImportAutoConfiguration(QdrantVectorStoreAutoConfiguration.class) static class Config { diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java index 4642ec66333..998e220df9b 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,14 +40,6 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -59,34 +61,34 @@ public class QdrantContainerWithApiKeyConnectionDetailsFactoryTest { @Autowired private VectorStore vectorStore; + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearch() { - vectorStore.add(documents); + this.vectorStore.add(this.documents); - List results = vectorStore + List results = this.vectorStore .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + results = this.vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); assertThat(results).hasSize(0); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) @ImportAutoConfiguration(QdrantVectorStoreAutoConfiguration.class) static class Config { diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java index a50b6576aa8..618e61cfccd 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java index 338b076a7a8..1aae0793ad8 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java @@ -1,6 +1,30 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.testcontainers.service.connection.typesense; +import java.time.Duration; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; @@ -15,13 +39,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.time.Duration; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -51,19 +68,19 @@ class TypesenseContainerConnectionDetailsFactoryTest { @Test public void addAndSearch() { - this.vectorStore.add(documents); + this.vectorStore.add(this.documents); List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()) .contains("Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); - this.vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); @@ -80,4 +97,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java index 406596506be..bf9982363b0 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.typesense; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java index 3a35f27d3a0..b30588e4eff 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.weaviate; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreAutoConfiguration; import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreProperties; import org.springframework.ai.document.Document; @@ -31,13 +40,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.weaviate.WeaviateContainer; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -63,15 +65,15 @@ class WeaviateContainerConnectionDetailsFactoryTest { @Test public void addAndSearchWithFilters() { - assertThat(properties.getFilterField()).hasSize(4); + assertThat(this.properties.getFilterField()).hasSize(4); - assertThat(properties.getFilterField().get("country")) + assertThat(this.properties.getFilterField().get("country")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.TEXT); - assertThat(properties.getFilterField().get("year")) + assertThat(this.properties.getFilterField().get("year")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.NUMBER); - assertThat(properties.getFilterField().get("active")) + assertThat(this.properties.getFilterField().get("active")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.BOOLEAN); - assertThat(properties.getFilterField().get("price")) + assertThat(this.properties.getFilterField().get("price")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.NUMBER); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -79,39 +81,39 @@ public void addAndSearchWithFilters() { var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", Map.of("country", "Netherlands", "price", 1.57, "active", false, "year", 2023)); - vectorStore.add(List.of(bgDocument, nlDocument)); + this.vectorStore.add(List.of(bgDocument, nlDocument)); var request = SearchRequest.query("The World").withTopK(5); - List results = vectorStore.similaritySearch(request); + List results = this.vectorStore.similaritySearch(request); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Bulgaria'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Netherlands'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); - results = vectorStore.similaritySearch( + results = this.vectorStore.similaritySearch( request.withSimilarityThresholdAll().withFilterExpression("price > 1.57 && active == true")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("year in [2020, 2023]")); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("year > 2020 && year <= 2023")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); // Remove all documents from the store - vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); } @Configuration(proxyBeanMethods = false) @@ -125,4 +127,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java index cdece3ffafc..8157d3c2e8d 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.weaviate; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-cloud-bindings/pom.xml b/spring-ai-spring-cloud-bindings/pom.xml index 2bf03d03c63..da3ffe77fb2 100644 --- a/spring-ai-spring-cloud-bindings/pom.xml +++ b/spring-ai-spring-cloud-bindings/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java index c53fedb3504..bc4dc5087cd 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java index e4cd607b764..67e1da572c4 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package org.springframework.ai.bindings; +import java.net.URI; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.net.URI; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java index 3a22a564fd2..07dabaa3da8 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.bindings; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java index 857fde84bd2..8afe9e393e2 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.bindings; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java index af98292e6f0..d00e04961a6 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.bindings; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java index 8832af47b3f..bb14eadd6e3 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package org.springframework.ai.bindings; +import java.util.Arrays; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Arrays; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java index 1b2160f4d7b..223210a1b71 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package org.springframework.ai.bindings; +import java.net.URI; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.net.URI; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories b/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories index 9562cc660d5..668f29e0e3f 100644 --- a/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Binding Properties Factories org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor=\ org.springframework.ai.bindings.ChromaBindingsPropertiesProcessor,\ diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java index d9023c336aa..885d8f25984 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.mock.env.MockEnvironment; -import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -51,19 +52,19 @@ class ChromaBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new ChromaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.host", "https://example.net"); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.port", "8000"); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.username", "itsme"); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.password", "youknowit"); + new ChromaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.host", "https://example.net"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.port", "8000"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.username", "itsme"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.password", "youknowit"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.chroma.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.chroma.enabled".formatted(CONFIG_PATH), "false"); - new ChromaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new ChromaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java index 05d175f3dc7..0c2d1db356b 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.mock.env.MockEnvironment; -import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -50,17 +51,17 @@ class MistralAiBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new MistralAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.mistralai.api-key", "demo"); - assertThat(properties).containsEntry("spring.ai.mistralai.base-url", "https://my.mistralai.example.net"); + new MistralAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.mistralai.api-key", "demo"); + assertThat(this.properties).containsEntry("spring.ai.mistralai.base-url", "https://my.mistralai.example.net"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.mistralai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.mistralai.enabled".formatted(CONFIG_PATH), "false"); - new MistralAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new MistralAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java index 247c9521088..b308fae9ac0 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.mock.env.MockEnvironment; -import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -49,16 +50,16 @@ class OllamaBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new OllamaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.ollama.base-url", "https://example.net/ollama:11434"); + new OllamaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.ollama.base-url", "https://example.net/ollama:11434"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.ollama.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.ollama.enabled".formatted(CONFIG_PATH), "false"); - new OllamaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new OllamaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java index 08225ab6712..fdebd11a2ef 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.mock.env.MockEnvironment; -import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -50,17 +51,17 @@ class OpenAiBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new OpenAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.openai.api-key", "demo"); - assertThat(properties).containsEntry("spring.ai.openai.base-url", "https://my.openai.example.net"); + new OpenAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.openai.api-key", "demo"); + assertThat(this.properties).containsEntry("spring.ai.openai.base-url", "https://my.openai.example.net"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.openai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.openai.enabled".formatted(CONFIG_PATH), "false"); - new OpenAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new OpenAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java index 14f33b9c043..40492754fca 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.mock.env.MockEnvironment; -import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -69,27 +70,29 @@ class TanzuBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new TanzuBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.openai.chat.api-key", "demo"); - assertThat(properties).containsEntry("spring.ai.openai.chat.base-url", "https://my.openai.example.net"); - assertThat(properties).containsEntry("spring.ai.openai.chat.options.model", "llava1.6"); - assertThat(properties).containsEntry("spring.ai.openai.embedding.api-key", "demo2"); - assertThat(properties).containsEntry("spring.ai.openai.embedding.base-url", "https://my.openai2.example.net"); - assertThat(properties).containsEntry("spring.ai.openai.embedding.options.model", "text-embed-large"); + new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.openai.chat.api-key", "demo"); + assertThat(this.properties).containsEntry("spring.ai.openai.chat.base-url", "https://my.openai.example.net"); + assertThat(this.properties).containsEntry("spring.ai.openai.chat.options.model", "llava1.6"); + assertThat(this.properties).containsEntry("spring.ai.openai.embedding.api-key", "demo2"); + assertThat(this.properties).containsEntry("spring.ai.openai.embedding.base-url", + "https://my.openai2.example.net"); + assertThat(this.properties).containsEntry("spring.ai.openai.embedding.options.model", "text-embed-large"); } @Test void propertiesAreMissingModelCapabilities() { - new TanzuBindingsPropertiesProcessor().process(environment, bindingsMissingModelCapabilities, properties); - assertThat(properties).isEmpty(); + new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindingsMissingModelCapabilities, + this.properties); + assertThat(this.properties).isEmpty(); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.genai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.genai.enabled".formatted(CONFIG_PATH), "false"); - new TanzuBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java index 9d91a8e843a..f48638ff233 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + import org.junit.jupiter.api.Test; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.mock.env.MockEnvironment; -import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -50,18 +51,18 @@ class WeaviateBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new WeaviateBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.vectorstore.weaviate.scheme", "https"); - assertThat(properties).containsEntry("spring.ai.vectorstore.weaviate.host", "example.net:8000"); - assertThat(properties).containsEntry("spring.ai.vectorstore.weaviate.api-key", "demo"); + new WeaviateBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.weaviate.scheme", "https"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.weaviate.host", "example.net:8000"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.weaviate.api-key", "demo"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.weaviate.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.weaviate.enabled".formatted(CONFIG_PATH), "false"); - new WeaviateBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new WeaviateBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-test/pom.xml b/spring-ai-test/pom.xml index 45cd0df17b5..3397a92ae7b 100644 --- a/spring-ai-test/pom.xml +++ b/spring-ai-test/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java index 2f4b01a79d2..ea056cca3f9 100644 --- a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java +++ b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.evaluation; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +package org.springframework.ai.evaluation; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.model.ChatModel; @@ -32,6 +31,9 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + public class BasicEvaluationTest { private static final Logger logger = LoggerFactory.getLogger(BasicEvaluationTest.class); @@ -56,23 +58,23 @@ protected void evaluateQuestionAndAnswer(String question, String answer, boolean assertThat(answer).isNotNull(); logger.info("Question: " + question); logger.info("Answer:" + answer); - PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, + PromptTemplate userPromptTemplate = new PromptTemplate(this.userEvaluatorResource, Map.of("question", question, "answer", answer)); SystemMessage systemMessage; if (factBased) { - systemMessage = new SystemMessage(qaEvaluatorFactBasedAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorFactBasedAnswerResource); } else { - systemMessage = new SystemMessage(qaEvaluatorAccurateAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorAccurateAnswerResource); } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiChatModel.call(prompt).getResult().getOutput().getContent(); + String yesOrNo = this.openAiChatModel.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { - SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); + SystemMessage notRelatedSystemMessage = new SystemMessage(this.qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiChatModel.call(prompt).getResult().getOutput().getContent(); + String reasonForFailure = this.openAiChatModel.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { @@ -81,4 +83,4 @@ protected void evaluateQuestionAndAnswer(String question, String answer, boolean } } -} \ No newline at end of file +} diff --git a/src/checkstyle/checkstyle-header.txt b/src/checkstyle/checkstyle-header.txt new file mode 100644 index 00000000000..9c623668092 --- /dev/null +++ b/src/checkstyle/checkstyle-header.txt @@ -0,0 +1,17 @@ +^\Q/*\E$ +^\Q * Copyright \E20\d\d\-20\d\d\Q the original author or authors.\E$ +^\Q *\E$ +^\Q * Licensed under the Apache License, Version 2.0 (the "License");\E$ +^\Q * you may not use this file except in compliance with the License.\E$ +^\Q * You may obtain a copy of the License at\E$ +^\Q *\E$ +^\Q * https://www.apache.org/licenses/LICENSE-2.0\E$ +^\Q *\E$ +^\Q * Unless required by applicable law or agreed to in writing, software\E$ +^\Q * distributed under the License is distributed on an "AS IS" BASIS,\E$ +^\Q * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\E$ +^\Q * See the License for the specific language governing permissions and\E$ +^\Q * limitations under the License.\E$ +^\Q */\E$ +^$ +^.*$ \ No newline at end of file diff --git a/src/checkstyle/checkstyle-suppressions.xml b/src/checkstyle/checkstyle-suppressions.xml new file mode 100644 index 00000000000..5f78aa52b5f --- /dev/null +++ b/src/checkstyle/checkstyle-suppressions.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + diff --git a/src/checkstyle/checkstyle.xml b/src/checkstyle/checkstyle.xml new file mode 100644 index 00000000000..2b6a846c1e2 --- /dev/null +++ b/src/checkstyle/checkstyle.xml @@ -0,0 +1,185 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml index 0410a18bda8..8c1e57ccdea 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml +++ b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java index cb5f1df3c39..294dffd8596 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,6 @@ package org.springframework.ai.vectorstore; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; - import java.util.Collection; import java.util.Map; import java.util.Optional; @@ -29,6 +26,9 @@ import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; + /** * Converts {@link org.springframework.ai.vectorstore.filter.Filter.Expression} into * Cosmos DB NoSQL API where clauses. @@ -51,7 +51,7 @@ public CosmosDBFilterExpressionConverter(Collection columns) { */ private Optional getMetadataField(String name) { String metadataField = name; - return Optional.ofNullable(metadataFields.get(metadataField)); + return Optional.ofNullable(this.metadataFields.get(metadataField)); } @Override diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java index 54716b142e4..8c469701770 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -24,20 +24,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; - import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosAsyncContainer; import com.azure.cosmos.CosmosAsyncDatabase; @@ -66,10 +52,23 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; - import io.micrometer.observation.ObservationRegistry; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; + /** * @author Theo van Kraay * @author Soby Chacko @@ -81,14 +80,14 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen private final CosmosAsyncClient cosmosClient; - private CosmosAsyncContainer container; - private final EmbeddingModel embeddingModel; private final CosmosDBVectorStoreConfig properties; private final BatchingStrategy batchingStrategy; + private CosmosAsyncContainer container; + public CosmosDBVectorStore(ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient, CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) { @@ -210,7 +209,7 @@ public void doAdd(List documents) { CosmosItemOperation operation = CosmosBulkOperations .getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId())); return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID - // with the operation + // with the operation }).toList(); try { @@ -233,7 +232,7 @@ public void doAdd(List documents) { String errorMessage = String.format("Duplicate document id: %s", documentId); logger.error(errorMessage); throw new RuntimeException(errorMessage); // Throw an exception - // for status code 409 + // for status code 409 } else { logger.info("Document added with status: {}", statusCode); @@ -307,10 +306,10 @@ public List doSimilaritySearch(SearchRequest request) { if (filterExpression != null) { CosmosDBFilterExpressionConverter filterExpressionConverter = new CosmosDBFilterExpressionConverter( this.properties.getMetadataFieldsList()); // Use the expression - // directly as - // it handles the - // "metadata" - // fields internally + // directly as + // it handles the + // "metadata" + // fields internally String filterQuery = filterExpressionConverter.convertExpression(filterExpression); queryBuilder.append(" AND ").append(filterQuery); } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java index 244ee814534..96729cbcfb7 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -51,15 +51,15 @@ public void setVectorStoreThroughput(int vectorStoreThroughput) { this.vectorStoreThroughput = vectorStoreThroughput; } + public String getMetadataFields() { + return this.metadataFields; + } + public void setMetadataFields(String metadataFields) { this.metadataFields = metadataFields; this.metadataFieldsList = List.of(metadataFields.split(",")); } - public String getMetadataFields() { - return this.metadataFields; - } - public List getMetadataFieldsList() { return this.metadataFieldsList; } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java index 6c064444390..d8432fa71d1 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,11 +16,17 @@ package org.springframework.ai.vectorstore; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosClientBuilder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -30,10 +36,6 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -53,8 +55,8 @@ public class CosmosDBVectorStoreIT { @BeforeEach public void setup() { - contextRunner.run(context -> { - vectorStore = context.getBean(VectorStore.class); + this.contextRunner.run(context -> { + this.vectorStore = context.getBean(VectorStore.class); }); } @@ -66,25 +68,25 @@ public void testAddSearchAndDeleteDocuments() { Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); // Add the document to the vector store - vectorStore.add(List.of(document1, document2)); + this.vectorStore.add(List.of(document1, document2)); // create duplicate docs and assert that second one throws exception Document document3 = new Document(document1.getId(), "Sample content3", Map.of("key3", "value3")); - assertThatThrownBy(() -> vectorStore.add(List.of(document3))).isInstanceOf(Exception.class) + assertThatThrownBy(() -> this.vectorStore.add(List.of(document3))).isInstanceOf(Exception.class) .hasMessageContaining("Duplicate document id: " + document1.getId()); // Perform a similarity search - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results).isNotEmpty(); assertThat(results.get(0).getId()).isEqualTo(document1.getId()); // Remove the documents from the vector store - vectorStore.delete(List.of(document1.getId(), document2.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId())); // Perform a similarity search again - List results2 = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results2).isEmpty(); @@ -124,16 +126,16 @@ void testSimilaritySearchWithFilter() { Document document3 = new Document("3", "A document about the US", metadata3); Document document4 = new Document("4", "A document about the US", metadata4); - vectorStore.add(List.of(document1, document2, document3, document4)); + this.vectorStore.add(List.of(document1, document2, document3, document4)); FilterExpressionBuilder b = new FilterExpressionBuilder(); - List results = vectorStore.similaritySearch(SearchRequest.query("The World") + List results = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression((b.in("country", "UK", "NL")).build())); assertThat(results).hasSize(2); assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); - List results2 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression( b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build())); @@ -141,17 +143,17 @@ void testSimilaritySearchWithFilter() { assertThat(results2).hasSize(1); assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1"); - List results3 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results3 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build())); assertThat(results3).hasSize(1); assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4"); - vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); // Perform a similarity search again - List results4 = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); + List results4 = this.vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); // Verify the search results assertThat(results4).isEmpty(); @@ -191,6 +193,7 @@ public EmbeddingModel embeddingModel() { public VectorStoreObservationConvention observationConvention() { // Replace with an actual observation convention or a mock if needed return new VectorStoreObservationConvention() { + }; } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties index 20c6c622002..82882acded4 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spring.ai.vectorstore.cosmosdb.databaseName=db spring.ai.vectorstore.cosmosdb.containerName=container spring.ai.vectorstore.cosmosdb.key=${COSMOSDB_AI_ENDPOINT} diff --git a/vector-stores/spring-ai-azure-store/pom.xml b/vector-stores/spring-ai-azure-store/pom.xml index fa819779c16..25bf7e5f11e 100644 --- a/vector-stores/spring-ai-azure-store/pom.xml +++ b/vector-stores/spring-ai-azure-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java index 2bcb1f3f0eb..ed127ea11cf 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.text.ParseException; diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 4a5ea34459d..e5ed2c45707 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -23,27 +23,6 @@ import java.util.Optional; import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; -import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.TypeReference; import com.azure.core.util.Context; @@ -63,8 +42,28 @@ import com.azure.search.documents.models.SearchOptions; import com.azure.search.documents.models.VectorSearchOptions; import com.azure.search.documents.models.VectorizedQuery; - import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * Uses Azure Cognitive Search as a backing vector store. Documents can be preloaded into @@ -81,14 +80,14 @@ */ public class AzureVectorStore extends AbstractObservationVectorStore implements InitializingBean { + public static final String DEFAULT_INDEX_NAME = "spring_ai_azure_vector_store"; + private static final Logger logger = LoggerFactory.getLogger(AzureVectorStore.class); private static final String SPRING_AI_VECTOR_CONFIG = "spring-ai-vector-config"; private static final String SPRING_AI_VECTOR_PROFILE = "spring-ai-vector-profile"; - public static final String DEFAULT_INDEX_NAME = "spring_ai_azure_vector_store"; - private static final String ID_FIELD_NAME = "id"; private static final String CONTENT_FIELD_NAME = "content"; @@ -109,16 +108,8 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements private final EmbeddingModel embeddingModel; - private SearchClient searchClient; - private final FilterExpressionConverter filterExpressionConverter; - private int defaultTopK = DEFAULT_TOP_K; - - private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; - - private String indexName = DEFAULT_INDEX_NAME; - private final boolean initializeSchema; private final BatchingStrategy batchingStrategy; @@ -134,32 +125,13 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements */ private final List filterMetadataFields; - public record MetadataField(String name, SearchFieldDataType fieldType) { - - public static MetadataField text(String name) { - return new MetadataField(name, SearchFieldDataType.STRING); - } - - public static MetadataField int32(String name) { - return new MetadataField(name, SearchFieldDataType.INT32); - } - - public static MetadataField int64(String name) { - return new MetadataField(name, SearchFieldDataType.INT64); - } + private SearchClient searchClient; - public static MetadataField decimal(String name) { - return new MetadataField(name, SearchFieldDataType.DOUBLE); - } + private int defaultTopK = DEFAULT_TOP_K; - public static MetadataField bool(String name) { - return new MetadataField(name, SearchFieldDataType.BOOLEAN); - } + private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; - public static MetadataField date(String name) { - return new MetadataField(name, SearchFieldDataType.DATE_TIME_OFFSET); - } - } + private String indexName = DEFAULT_INDEX_NAME; /** * Constructs a new AzureCognitiveSearchVectorStore. @@ -320,7 +292,7 @@ public List doSimilaritySearch(SearchRequest request) { Assert.notNull(request, "The search request must not be null."); - var searchEmbedding = embeddingModel.embed(request.getQuery()); + var searchEmbedding = this.embeddingModel.embed(request.getQuery()); final var vectorQuery = new VectorizedQuery(EmbeddingUtils.toList(searchEmbedding)) .setKNearestNeighborsCount(request.getTopK()) @@ -336,7 +308,7 @@ public List doSimilaritySearch(SearchRequest request) { searchOptions.setFilter(oDataFilter); } - final var searchResults = searchClient.search(null, searchOptions, Context.NONE); + final var searchResults = this.searchClient.search(null, searchOptions, Context.NONE); return searchResults.stream() .filter(result -> result.getScore() >= request.getSimilarityThreshold()) @@ -346,6 +318,7 @@ public List doSimilaritySearch(SearchRequest request) { Map metadata = (StringUtils.hasText(entry.metadata())) ? JSONObject.parseObject(entry.metadata(), new TypeReference>() { + }) : Map.of(); metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore()); @@ -359,12 +332,6 @@ public List doSimilaritySearch(SearchRequest request) { .collect(Collectors.toList()); } - /** - * Internal data structure for retrieving and storing documents. - */ - private record AzureSearchDocument(String id, String content, List embedding, String metadata) { - } - @Override public void afterPropertiesSet() throws Exception { @@ -426,4 +393,39 @@ public Builder createObservationContextBuilder(String operationName) { .withSimilarityMetric(this.initializeSchema ? VectorStoreSimilarityMetric.COSINE.value() : null); } + public record MetadataField(String name, SearchFieldDataType fieldType) { + + public static MetadataField text(String name) { + return new MetadataField(name, SearchFieldDataType.STRING); + } + + public static MetadataField int32(String name) { + return new MetadataField(name, SearchFieldDataType.INT32); + } + + public static MetadataField int64(String name) { + return new MetadataField(name, SearchFieldDataType.INT64); + } + + public static MetadataField decimal(String name) { + return new MetadataField(name, SearchFieldDataType.DOUBLE); + } + + public static MetadataField bool(String name) { + return new MetadataField(name, SearchFieldDataType.BOOLEAN); + } + + public static MetadataField date(String name) { + return new MetadataField(name, SearchFieldDataType.DATE_TIME_OFFSET); + } + + } + + /** + * Internal data structure for retrieving and storing documents. + */ + private record AzureSearchDocument(String id, String content, List embedding, String metadata) { + + } + } diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java index 65d67fcc651..2bd63829050 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.util.Date; @@ -204,4 +205,4 @@ public void testComplexIdentifiers() { """); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index 9b980817720..03418bb2028 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.io.IOException; @@ -56,14 +57,14 @@ @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") public class AzureVectorStoreIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -71,14 +72,24 @@ public static void beforeAll() { Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); } + private static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), @@ -88,14 +99,14 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); @@ -105,7 +116,7 @@ public void addAndSearchTest() { @Test public void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", @@ -194,7 +205,7 @@ public void searchWithFilters() throws InterruptedException { @Test public void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -247,11 +258,11 @@ public void documentUpdateTest() { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -272,13 +283,13 @@ public void searchThresholdTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); @@ -309,14 +320,4 @@ public EmbeddingModel embeddingModel() { } - private static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java index 10c876db0d2..6ce752d5fb0 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.azure; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.azure; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,10 +23,17 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.search.documents.indexes.SearchIndexClient; +import com.azure.search.documents.indexes.SearchIndexClientBuilder; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -47,13 +53,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.search.documents.indexes.SearchIndexClient; -import com.azure.search.documents.indexes.SearchIndexClientBuilder; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation AbstractObservationVectorStore in @@ -66,6 +66,9 @@ @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") public class AzureVectorStoreObservationIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -81,9 +84,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -94,13 +94,13 @@ public static void beforeAll() { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-cassandra-store/pom.xml b/vector-stores/spring-ai-cassandra-store/pom.xml index c676627b97a..1032f363547 100644 --- a/vector-stores/spring-ai-cassandra-store/pom.xml +++ b/vector-stores/spring-ai-cassandra-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java index f945b9db8bf..11ca9efe005 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java @@ -1,30 +1,29 @@ /* + * Copyright 2023-2024 the original author or authors. + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * See the NOTICE file distributed with this work for additional information - * regarding copyright ownership. */ + package org.springframework.ai.cassandra; +import java.time.Duration; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.time.Duration; - /** * @author Mick Semb Wever * @since 1.0.0 diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java index 7dc82f37012..c12c81dd65a 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; + import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.Row; @@ -31,18 +37,13 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - /** * Create a CassandraChatMemory like -CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); - + CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); + * * For example @see org.springframework.ai.chat.memory.CassandraChatMemory - * + * * @author Mick Semb Wever * @since 1.0.0 */ @@ -54,10 +55,6 @@ public final class CassandraChatMemory implements ChatMemory { private final PreparedStatement addUserStmt, addAssistantStmt, getStmt, deleteStmt; - public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { - return new CassandraChatMemory(conf); - } - public CassandraChatMemory(CassandraChatMemoryConfig config) { this.conf = config; this.conf.ensureSchemaExists(); @@ -67,6 +64,10 @@ public CassandraChatMemory(CassandraChatMemoryConfig config) { this.deleteStmt = prepareDeleteStmt(); } + public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { + return new CassandraChatMemory(conf); + } + @Override public void add(String conversationId, List messages) { final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli()); @@ -90,8 +91,8 @@ public void add(String sessionId, Message msg) { PreparedStatement stmt; switch (msg.getMessageType()) { - case USER -> stmt = addUserStmt; - case ASSISTANT -> stmt = addAssistantStmt; + case USER -> stmt = this.addUserStmt; + case ASSISTANT -> stmt = this.addAssistantStmt; default -> throw new IllegalArgumentException("Cant add type " + msg); } @@ -115,7 +116,7 @@ public void add(String sessionId, Message msg) { public void clear(String sessionId) { List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = deleteStmt.boundStatementBuilder(); + BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder(); for (int k = 0; k < primaryKeys.size(); ++k) { SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); @@ -129,7 +130,7 @@ public void clear(String sessionId) { public List get(String sessionId, int lastN) { List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = getStmt.boundStatementBuilder().setInt("lastN", lastN); + BoundStatementBuilder builder = this.getStmt.boundStatementBuilder().setInt("lastN", lastN); for (int k = 0; k < primaryKeys.size(); ++k) { SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java index ab046e2eca6..3c9f329b6ec 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -32,42 +40,17 @@ import com.datastax.oss.driver.api.querybuilder.schema.CreateTableWithOptions; import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.cassandra.SchemaUtil; -import java.net.InetSocketAddress; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.function.Function; - /** * @author Mick Semb Wever * @since 1.0.0 */ public final class CassandraChatMemoryConfig { - private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryConfig.class); - - record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys) { - } - - public record SchemaColumn(String name, DataType type) { - - public GenericType javaType() { - return CodecRegistry.DEFAULT.codecFor(type).getJavaType(); - } - } - - /** Given a string sessionId, return the value for each primary key column. */ - public interface SessionIdToPrimaryKeysTranslator extends Function> { - - } - public static final String DEFAULT_KEYSPACE_NAME = "springframework"; public static final String DEFAULT_TABLE_NAME = "ai_chat_memory"; @@ -82,6 +65,8 @@ public interface SessionIdToPrimaryKeysTranslator extends Function> { + + } + + record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys) { + + } + + public record SchemaColumn(String name, DataType type) { + + public GenericType javaType() { + return CodecRegistry.DEFAULT.codecFor(this.type).getJavaType(); + } + + } + public static class Builder { private CqlSession session = null; @@ -226,14 +318,14 @@ public Builder withChatExchangeToPrimaryKeyTranslator(SessionIdToPrimaryKeysTran public CassandraChatMemoryConfig build() { - int primaryKeyColumns = partitionKeys.size() + clusteringKeys.size(); + int primaryKeyColumns = this.partitionKeys.size() + this.clusteringKeys.size(); int primaryKeysToBind = this.primaryKeyTranslator.apply(UUID.randomUUID().toString()).size(); Preconditions.checkArgument(primaryKeyColumns == primaryKeysToBind + 1, "The primaryKeyTranslator must always return one less element than the number of primary keys in total. The last clustering key remains undefined, expecting to be the timestamp for messages within sessionId. The sessionId can map to any primary key column (though it should map to a partition key column)."); Preconditions.checkArgument( - clusteringKeys.get(clusteringKeys.size() - 1).name().equals(DEFAULT_EXCHANGE_ID_NAME), + this.clusteringKeys.get(this.clusteringKeys.size() - 1).name().equals(DEFAULT_EXCHANGE_ID_NAME), "last clustering key must be the exchangeIdColumn"); return new CassandraChatMemoryConfig(this); @@ -241,92 +333,4 @@ public CassandraChatMemoryConfig build() { } - void ensureSchemaExists() { - if (!disallowSchemaChanges) { - SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); - ensureTableExists(); - ensureTableColumnsExist(); - SchemaUtil.checkSchemaAgreement(this.session); - } - else { - checkSchemaValid(); - } - } - - void checkSchemaValid() { - - Preconditions.checkState(session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), - "keyspace %s does not exist", this.schema.keyspace); - - Preconditions.checkState( - session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isPresent(), - "table %s does not exist"); - - TableMetadata tableMetadata = session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .get(); - - Preconditions.checkState(tableMetadata.getColumn(this.assistantColumn).isPresent(), "column %s does not exist", - this.assistantColumn); - - Preconditions.checkState(tableMetadata.getColumn(this.userColumn).isPresent(), "column %s does not exist", - this.userColumn); - } - - private void ensureTableExists() { - if (session.getMetadata().getKeyspace(schema.keyspace).get().getTable(this.schema.table).isEmpty()) { - CreateTable createTable = null; - - CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) - .ifNotExists(); - - for (SchemaColumn partitionKey : this.schema.partitionKeys) { - createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, - partitionKey.type); - } - for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { - createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); - } - - String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name(); - - CreateTableWithOptions createTableWithOptions = createTable.withColumn(this.userColumn, DataTypes.TEXT) - .withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC) - // TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() is available - .withOption("compaction", Map.of("class", "UnifiedCompactionStrategy")); - - if (null != this.timeToLiveSeconds) { - createTableWithOptions = createTableWithOptions.withDefaultTimeToLiveSeconds(this.timeToLiveSeconds); - } - this.session.execute(createTableWithOptions.build()); - } - } - - private void ensureTableColumnsExist() { - - TableMetadata tableMetadata = this.session.getMetadata() - .getKeyspace(this.schema.keyspace()) - .get() - .getTable(this.schema.table()) - .get(); - - boolean addAssistantColumn = tableMetadata.getColumn(this.assistantColumn).isEmpty(); - boolean addUserColumn = tableMetadata.getColumn(this.userColumn).isEmpty(); - - if (addAssistantColumn || addUserColumn) { - AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table()); - if (addAssistantColumn) { - alterTable = alterTable.addColumn(this.assistantColumn, DataTypes.TEXT); - } - if (addUserColumn) { - alterTable = alterTable.addColumn(this.userColumn, DataTypes.TEXT); - } - SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); - logger.debug("Executing {}", stmt.getQuery()); - this.session.execute(stmt); - } - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java index aef2a56b296..ddb3104092e 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; @@ -26,12 +33,6 @@ import org.springframework.ai.vectorstore.filter.Filter.Value; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; -import java.util.Collection; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - /** * Converts {@link org.springframework.ai.vectorstore.filter.Filter.Expression} into CQL * where clauses. @@ -49,6 +50,24 @@ public CassandraFilterExpressionConverter(Collection columns) { .collect(Collectors.toMap((c) -> c.getName().asInternal(), Function.identity())); } + private static void doOperand(ExpressionType type, StringBuilder context) { + switch (type) { + case EQ -> context.append(" = "); + case NE -> context.append(" != "); + case GT -> context.append(" > "); + case GTE -> context.append(" >= "); + case IN -> context.append(" IN "); + case LT -> context.append(" < "); + case LTE -> context.append(" <= "); + // TODO SAI supports collections + // reach out to mck@apache.org if you'd like these implemented + // case CONTAINS -> context.append(" CONTAINS "); + // case CONTAINS_KEY -> context.append(" CONTAINS_KEY "); + default -> throw new UnsupportedOperationException( + String.format("Expression type %s not yet implemented. Patches welcome.", type)); + } + } + @Override protected void doKey(Key key, StringBuilder context) { String keyName = key.key(); @@ -68,24 +87,6 @@ protected void doExpression(Filter.Expression expression, StringBuilder context) } } - private static void doOperand(ExpressionType type, StringBuilder context) { - switch (type) { - case EQ -> context.append(" = "); - case NE -> context.append(" != "); - case GT -> context.append(" > "); - case GTE -> context.append(" >= "); - case IN -> context.append(" IN "); - case LT -> context.append(" < "); - case LTE -> context.append(" <= "); - // TODO SAI supports collections - // reach out to mck@apache.org if you'd like these implemented - // case CONTAINS -> context.append(" CONTAINS "); - // case CONTAINS_KEY -> context.append(" CONTAINS_KEY "); - default -> throw new UnsupportedOperationException( - String.format("Expression type %s not yet implemented. Patches welcome.", type)); - } - } - private void doBinaryOperation(String operator, Filter.Expression expression, StringBuilder context) { this.convertOperand(expression.left(), context); context.append(operator); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 35bb49420f3..349463646f6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,17 @@ package org.springframework.ai.vectorstore; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; @@ -29,9 +40,7 @@ import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; - import io.micrometer.observation.ObservationRegistry; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,17 +59,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - /** * The CassandraVectorStore is for managing and querying vector data in an Apache * Cassandra db. It offers functionalities like adding, deleting, and performing @@ -108,16 +106,6 @@ */ public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable { - /** - * Indexes are automatically created with COSINE. This can be changed manually via - * cqlsh - */ - public enum Similarity { - - COSINE, DOT_PRODUCT, EUCLIDEAN; - - } - public static final String SIMILARITY_FIELD_NAME = "similarity_score"; public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; @@ -128,6 +116,10 @@ public enum Similarity { private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); + private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE, + VectorStoreSimilarityMetric.COSINE, Similarity.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, + Similarity.DOT_PRODUCT, VectorStoreSimilarityMetric.DOT); + private final CassandraVectorStoreConfig conf; private final EmbeddingModel embeddingModel; @@ -177,6 +169,15 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe this.batchingStrategy = batchingStrategy; } + private static Float[] toFloatArray(float[] embedding) { + Float[] embeddingFloat = new Float[embedding.length]; + int i = 0; + for (Float d : embedding) { + embeddingFloat[i++] = d.floatValue(); + } + return embeddingFloat; + } + @Override public void doAdd(List documents) { var futures = new CompletableFuture[documents.size()]; @@ -275,7 +276,7 @@ public void close() throws Exception { } void checkSchemaValid() { - this.conf.checkSchemaValid(embeddingModel.dimensions()); + this.conf.checkSchemaValid(this.embeddingModel.dimensions()); } private Similarity getIndexSimilarity(TableMetadata metadata) { @@ -289,7 +290,7 @@ private Similarity getIndexSimilarity(TableMetadata metadata) { private PreparedStatement prepareDeleteStatement() { Delete stmt = null; - DeleteSelection stmtStart = QueryBuilder.deleteFrom(conf.schema.keyspace(), conf.schema.table()); + DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table()); for (var c : this.conf.schema.partitionKeys()) { stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); @@ -344,7 +345,7 @@ private String similaritySearchStatement() { String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase()) .append('(') - .append(conf.schema.embedding()) + .append(this.conf.schema.embedding()) .append(",?)") .toString(); @@ -377,15 +378,6 @@ private String getDocumentId(Row row) { return this.conf.primaryKeyTranslator.apply(primaryKeyValues); } - private static Float[] toFloatArray(float[] embedding) { - Float[] embeddingFloat = new Float[embedding.length]; - int i = 0; - for (Float d : embedding) { - embeddingFloat[i++] = d.floatValue(); - } - return embeddingFloat; - } - @Override public Builder createObservationContextBuilder(String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CASSANDRA.value(), operationName) @@ -395,10 +387,6 @@ public Builder createObservationContextBuilder(String operationName) { .withSimilarityMetric(getSimilarityMetric()); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE, - VectorStoreSimilarityMetric.COSINE, Similarity.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, - Similarity.DOT_PRODUCT, VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.similarity)) { return this.similarity.name(); @@ -406,4 +394,14 @@ private String getSimilarityMetric() { return SIMILARITY_TYPE_MAPPING.get(this.similarity).value(); } + /** + * Indexes are automatically created with COSINE. This can be changed manually via + * cqlsh + */ + public enum Similarity { + + COSINE, DOT_PRODUCT, EUCLIDEAN; + + } + } diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java index e50e8008c5b..65bcba011de 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.function.Function; +import java.util.stream.Stream; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -32,23 +44,11 @@ import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart; import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.lang.Nullable; import org.springframework.ai.cassandra.SchemaUtil; - -import java.net.InetSocketAddress; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.function.Function; -import java.util.stream.Stream; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import org.springframework.lang.Nullable; /** * Configuration for the Cassandra vector store. @@ -84,51 +84,6 @@ public class CassandraVectorStoreConfig implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStoreConfig.class); - record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, - String content, String embedding, String index, Set metadataColumns) { - - } - - public record SchemaColumn(String name, DataType type, SchemaColumnTags... tags) { - public SchemaColumn(String name, DataType type) { - this(name, type, new SchemaColumnTags[0]); - } - - public GenericType javaType() { - return CodecRegistry.DEFAULT.codecFor(type).getJavaType(); - } - - public boolean indexed() { - for (SchemaColumnTags t : tags) { - if (SchemaColumnTags.INDEXED == t) { - return true; - } - } - return false; - } - } - - public enum SchemaColumnTags { - - INDEXED - - } - - /** - * Given a string document id, return the value for each primary key column. - * - * It is a requirement that an empty {@code List} returns an example formatted - * id - */ - public interface DocumentIdTranslator extends Function> { - - } - - /** Given a list of primary key column values, return the document id. */ - public interface PrimaryKeyTranslator extends Function, String> { - - } - final CqlSession session; final Schema schema; @@ -181,6 +136,232 @@ void dropKeyspace() { this.session.execute(SchemaBuilder.dropKeyspace(this.schema.keyspace).ifExists().build()); } + void ensureSchemaExists(int vectorDimension) { + if (!this.disallowSchemaChanges) { + SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); + ensureTableExists(vectorDimension); + ensureTableColumnsExist(vectorDimension); + ensureIndexesExists(); + SchemaUtil.checkSchemaAgreement(this.session); + } + else { + checkSchemaValid(vectorDimension); + } + } + + void checkSchemaValid(int vectorDimension) { + + Preconditions.checkState(this.session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), + "keyspace %s does not exist", this.schema.keyspace); + + Preconditions.checkState(this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .isPresent(), "table %s does not exist"); + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Preconditions.checkState(tableMetadata.getColumn(this.schema.content).isPresent(), "column %s does not exist", + this.schema.content); + + Preconditions.checkState(tableMetadata.getColumn(this.schema.embedding).isPresent(), "column %s does not exist", + this.schema.embedding); + + for (SchemaColumn m : this.schema.metadataColumns) { + Optional column = tableMetadata.getColumn(m.name()); + Preconditions.checkState(column.isPresent(), "column %s does not exist", m.name()); + + Preconditions.checkArgument(column.get().getType().equals(m.type()), + "Mismatching type on metadata column %s of %s vs %s", m.name(), column.get().getType(), m.type()); + + if (m.indexed()) { + Preconditions.checkState( + tableMetadata.getIndexes().values().stream().anyMatch((i) -> i.getTarget().equals(m.name())), + "index %s does not exist", m.name()); + } + } + + } + + private void ensureIndexesExists() { + { + SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(this.schema.embedding) + .build(); + + logger.debug("Executing {}", indexStmt.getQuery()); + this.session.execute(indexStmt); + } + Stream + .concat(this.schema.partitionKeys.stream(), + Stream.concat(this.schema.clusteringKeys.stream(), this.schema.metadataColumns.stream())) + .filter((cs) -> cs.indexed()) + .forEach((metadata) -> { + + SimpleStatement indexStmt = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(metadata.name()) + .build(); + + logger.debug("Executing {}", indexStmt.getQuery()); + this.session.execute(indexStmt); + }); + } + + private void ensureTableExists(int vectorDimension) { + if (this.session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isEmpty()) { + + CreateTable createTable = null; + + CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) + .ifNotExists(); + + for (SchemaColumn partitionKey : this.schema.partitionKeys) { + createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, + partitionKey.type); + } + for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { + createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); + } + + createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT); + + for (SchemaColumn metadata : this.schema.metadataColumns) { + createTable = createTable.withColumn(metadata.name(), metadata.type()); + } + + // https://datastax-oss.atlassian.net/browse/JAVA-3118 + // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT, + // vectorDimension)); + + StringBuilder tableStmt = new StringBuilder(createTable.asCql()); + tableStmt.setLength(tableStmt.length() - 1); + tableStmt.append(',') + .append(this.schema.embedding) + .append(" vector)"); + logger.debug("Executing {}", tableStmt.toString()); + this.session.execute(tableStmt.toString()); + } + } + + private void ensureTableColumnsExist(int vectorDimension) { + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Set newColumns = new HashSet<>(); + boolean addContent = tableMetadata.getColumn(this.schema.content).isEmpty(); + boolean addEmbedding = tableMetadata.getColumn(this.schema.embedding).isEmpty(); + + for (SchemaColumn metadata : this.schema.metadataColumns) { + Optional column = tableMetadata.getColumn(metadata.name()); + if (column.isPresent()) { + + Preconditions.checkArgument(column.get().getType().equals(metadata.type()), + "Cannot change type on metadata column %s from %s to %s", metadata.name(), + column.get().getType(), metadata.type()); + } + else { + newColumns.add(metadata); + } + } + + if (!newColumns.isEmpty() || addContent || addEmbedding) { + AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace, this.schema.table); + for (SchemaColumn metadata : newColumns) { + alterTable = alterTable.addColumn(metadata.name(), metadata.type()); + } + if (addContent) { + alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT); + } + if (addEmbedding) { + // special case for embedding column, bc JAVA-3118, as above + StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); + if (newColumns.isEmpty() && !addContent) { + alterTableStmt.append(" ADD ("); + } + else { + alterTableStmt.setLength(alterTableStmt.length() - 1); + alterTableStmt.append(','); + } + alterTableStmt.append(this.schema.embedding) + .append(" vector)"); + + logger.debug("Executing {}", alterTableStmt.toString()); + this.session.execute(alterTableStmt.toString()); + } + else { + SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); + logger.debug("Executing {}", stmt.getQuery()); + this.session.execute(stmt); + } + } + } + + public enum SchemaColumnTags { + + INDEXED + + } + + /** + * Given a string document id, return the value for each primary key column. + * + * It is a requirement that an empty {@code List} returns an example formatted + * id + */ + public interface DocumentIdTranslator extends Function> { + + } + + /** Given a list of primary key column values, return the document id. */ + public interface PrimaryKeyTranslator extends Function, String> { + + } + + record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, + String content, String embedding, String index, Set metadataColumns) { + + } + + public record SchemaColumn(String name, DataType type, SchemaColumnTags... tags) { + + public SchemaColumn(String name, DataType type) { + this(name, type, new SchemaColumnTags[0]); + } + + public GenericType javaType() { + return CodecRegistry.DEFAULT.codecFor(this.type).getJavaType(); + } + + public boolean indexed() { + for (SchemaColumnTags t : this.tags) { + if (SchemaColumnTags.INDEXED == t) { + return true; + } + } + return false; + } + + } + public static class Builder { private CqlSession session = null; @@ -383,183 +564,4 @@ public CassandraVectorStoreConfig build() { } - void ensureSchemaExists(int vectorDimension) { - if (!this.disallowSchemaChanges) { - SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); - ensureTableExists(vectorDimension); - ensureTableColumnsExist(vectorDimension); - ensureIndexesExists(); - SchemaUtil.checkSchemaAgreement(session); - } - else { - checkSchemaValid(vectorDimension); - } - } - - void checkSchemaValid(int vectorDimension) { - - Preconditions.checkState(this.session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), - "keyspace %s does not exist", this.schema.keyspace); - - Preconditions.checkState(this.session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .isPresent(), "table %s does not exist"); - - TableMetadata tableMetadata = this.session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .get(); - - Preconditions.checkState(tableMetadata.getColumn(this.schema.content).isPresent(), "column %s does not exist", - this.schema.content); - - Preconditions.checkState(tableMetadata.getColumn(this.schema.embedding).isPresent(), "column %s does not exist", - this.schema.embedding); - - for (SchemaColumn m : this.schema.metadataColumns) { - Optional column = tableMetadata.getColumn(m.name()); - Preconditions.checkState(column.isPresent(), "column %s does not exist", m.name()); - - Preconditions.checkArgument(column.get().getType().equals(m.type()), - "Mismatching type on metadata column %s of %s vs %s", m.name(), column.get().getType(), m.type()); - - if (m.indexed()) { - Preconditions.checkState( - tableMetadata.getIndexes().values().stream().anyMatch((i) -> i.getTarget().equals(m.name())), - "index %s does not exist", m.name()); - } - } - - } - - private void ensureIndexesExists() { - { - SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) - .ifNotExists() - .custom("StorageAttachedIndex") - .onTable(this.schema.keyspace, this.schema.table) - .andColumn(this.schema.embedding) - .build(); - - logger.debug("Executing {}", indexStmt.getQuery()); - this.session.execute(indexStmt); - } - Stream - .concat(this.schema.partitionKeys.stream(), - Stream.concat(this.schema.clusteringKeys.stream(), this.schema.metadataColumns.stream())) - .filter((cs) -> cs.indexed()) - .forEach((metadata) -> { - - SimpleStatement indexStmt = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) - .ifNotExists() - .custom("StorageAttachedIndex") - .onTable(this.schema.keyspace, this.schema.table) - .andColumn(metadata.name()) - .build(); - - logger.debug("Executing {}", indexStmt.getQuery()); - this.session.execute(indexStmt); - }); - } - - private void ensureTableExists(int vectorDimension) { - if (this.session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isEmpty()) { - - CreateTable createTable = null; - - CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) - .ifNotExists(); - - for (SchemaColumn partitionKey : this.schema.partitionKeys) { - createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, - partitionKey.type); - } - for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { - createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); - } - - createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT); - - for (SchemaColumn metadata : this.schema.metadataColumns) { - createTable = createTable.withColumn(metadata.name(), metadata.type()); - } - - // https://datastax-oss.atlassian.net/browse/JAVA-3118 - // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT, - // vectorDimension)); - - StringBuilder tableStmt = new StringBuilder(createTable.asCql()); - tableStmt.setLength(tableStmt.length() - 1); - tableStmt.append(',') - .append(this.schema.embedding) - .append(" vector)"); - logger.debug("Executing {}", tableStmt.toString()); - this.session.execute(tableStmt.toString()); - } - } - - private void ensureTableColumnsExist(int vectorDimension) { - - TableMetadata tableMetadata = this.session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .get(); - - Set newColumns = new HashSet<>(); - boolean addContent = tableMetadata.getColumn(this.schema.content).isEmpty(); - boolean addEmbedding = tableMetadata.getColumn(this.schema.embedding).isEmpty(); - - for (SchemaColumn metadata : this.schema.metadataColumns) { - Optional column = tableMetadata.getColumn(metadata.name()); - if (column.isPresent()) { - - Preconditions.checkArgument(column.get().getType().equals(metadata.type()), - "Cannot change type on metadata column %s from %s to %s", metadata.name(), - column.get().getType(), metadata.type()); - } - else { - newColumns.add(metadata); - } - } - - if (!newColumns.isEmpty() || addContent || addEmbedding) { - AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace, this.schema.table); - for (SchemaColumn metadata : newColumns) { - alterTable = alterTable.addColumn(metadata.name(), metadata.type()); - } - if (addContent) { - alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT); - } - if (addEmbedding) { - // special case for embedding column, bc JAVA-3118, as above - StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); - if (newColumns.isEmpty() && !addContent) { - alterTableStmt.append(" ADD ("); - } - else { - alterTableStmt.setLength(alterTableStmt.length() - 1); - alterTableStmt.append(','); - } - alterTableStmt.append(this.schema.embedding) - .append(" vector)"); - - logger.debug("Executing {}", alterTableStmt.toString()); - this.session.execute(alterTableStmt.toString()); - } - else { - SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); - logger.debug("Executing {}", stmt.getQuery()); - this.session.execute(stmt); - } - } - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java index 8750aa3fe1d..cc70cd97ad9 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java index 9e64261132b..802b046b536 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; import java.time.Duration; @@ -21,11 +22,11 @@ import com.datastax.oss.driver.api.core.CqlSessionBuilder; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.springframework.ai.CassandraImage; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.ai.CassandraImage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java index 9cf501c2002..21d07e4be1b 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Collection; @@ -20,8 +21,8 @@ import java.util.List; import java.util.Set; -import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.internal.core.metadata.schema.DefaultColumnMetadata; import org.junit.jupiter.api.Assertions; @@ -47,6 +48,16 @@ */ class CassandraFilterExpressionConverterTests { + private static final CqlIdentifier T = CqlIdentifier.fromInternal("test"); + + private static final Collection COLUMNS = Set.of( + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("id"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("content"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("country"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("genre"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("drama"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("year"), DataTypes.SMALLINT, false)); + @Test void testEQOnPartition() { @@ -199,14 +210,4 @@ void testComplexIdentifiers() { assertThat(vectorExpr).isEqualTo("\"'country 1 2 3'\" = 'BG'"); } - private static final CqlIdentifier T = CqlIdentifier.fromInternal("test"); - - private static final Collection COLUMNS = Set.of( - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("id"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("content"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("country"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("genre"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("drama"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("year"), DataTypes.SMALLINT, false)); - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java index 177c50e5393..b7684bcc6f6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -36,12 +37,12 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.CassandraImage; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils; +import org.springframework.ai.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -91,6 +92,64 @@ class CassandraRichSchemaVectorStoreIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestApplication.class); + static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, + List columnOverrides) throws IOException { + + Optional wikiOverride = columnOverrides.stream() + .filter((f) -> "wiki".equals(f.name())) + .findFirst(); + + Optional langOverride = columnOverrides.stream() + .filter((f) -> "language".equals(f.name())) + .findFirst(); + + Optional titleOverride = columnOverrides.stream() + .filter((f) -> "title".equals(f.name())) + .findFirst(); + + Optional chunkNoOverride = columnOverrides.stream() + .filter((f) -> "chunk_no".equals(f.name())) + .findFirst(); + + SchemaColumn wikiSC = wikiOverride.orElse(new SchemaColumn("wiki", DataTypes.TEXT)); + SchemaColumn langSC = langOverride.orElse(new SchemaColumn("language", DataTypes.TEXT)); + SchemaColumn titleSC = titleOverride.orElse(new SchemaColumn("title", DataTypes.TEXT)); + SchemaColumn chunkNoSC = chunkNoOverride.orElse(new SchemaColumn("chunk_no", DataTypes.INT)); + + List partitionKeys = List.of(wikiSC, langSC, titleSC); + List clusteringKeys = List.of(chunkNoSC); + + CassandraVectorStoreConfig.Builder builder = CassandraVectorStoreConfig.builder() + .withCqlSession(context.getBean(CqlSession.class)) + .withKeyspaceName("test_wikidata") + .withTableName("articles") + .withPartitionKeys(partitionKeys) + .withClusteringKeys(clusteringKeys) + .withContentColumnName("body") + .withEmbeddingColumnName("all_minilm_l6_v2_embedding") + .withIndexName("all_minilm_l6_v2_ann") + + .addMetadataColumns(new SchemaColumn("revision", DataTypes.INT), + new SchemaColumn("id", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)) + + // this store uses '§¶' as a deliminator in the document id between db columns + // 'title' and 'chunk_no' + .withPrimaryKeyTranslator((List primaryKeys) -> { + if (primaryKeys.isEmpty()) { + return "test§¶0"; + } + return format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); + }) + .withDocumentIdTranslator((id) -> { + String[] parts = id.split("§¶"); + String title = parts[0]; + int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; + return List.of("simplewiki", "en", title, chunk_no); + }); + + return builder; + } + @Test void ensureSchemaCreation() { this.contextRunner.run(context -> { @@ -157,7 +216,7 @@ void ensureSchemaPartialCreation() { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -192,7 +251,7 @@ void addAndSearchPoormansBench() { int docsPerAdd = 12; // 128; int rounds = 3; - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = new CassandraVectorStore( storeBuilder(context, List.of()).withFixedThreadPoolExecutorSize(nThreads).build(), @@ -231,7 +290,7 @@ void addAndSearchPoormansBench() { @Test void searchWithPartitionFilter() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -282,7 +341,7 @@ void searchWithPartitionFilter() throws InterruptedException { @Test void unsearchableFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -301,7 +360,7 @@ void unsearchableFilters() throws InterruptedException { @Test void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -366,7 +425,7 @@ void searchWithFilters() throws InterruptedException { @Test void searchWithFilterOnPrimaryKeys() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { List overrides = List.of( new SchemaColumn("title", DataTypes.TEXT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED), @@ -402,7 +461,7 @@ void searchWithFilterOnPrimaryKeys() throws InterruptedException { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -453,7 +512,7 @@ void documentUpdate() { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -483,27 +542,6 @@ void searchWithThreshold() { }); } - @SpringBootConfiguration - @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) - public static class TestApplication { - - @Bean - public EmbeddingModel embeddingModel() { - // default is ONNX all-MiniLM-L6-v2 - return new TransformersEmbeddingModel(); - } - - @Bean - public CqlSession cqlSession() { - return new CqlSessionBuilder() - // comment next two lines out to connect to a local C* cluster - .addContactPoint(cassandraContainer.getContactPoint()) - .withLocalDatacenter(cassandraContainer.getLocalDatacenter()) - .build(); - } - - } - private StoreWrapper createStore(ApplicationContext context, boolean disallowSchemaCreation) throws IOException { @@ -526,64 +564,6 @@ private StoreWrapper createSto return new StoreWrapper(new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)), conf); } - static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, - List columnOverrides) throws IOException { - - Optional wikiOverride = columnOverrides.stream() - .filter((f) -> "wiki".equals(f.name())) - .findFirst(); - - Optional langOverride = columnOverrides.stream() - .filter((f) -> "language".equals(f.name())) - .findFirst(); - - Optional titleOverride = columnOverrides.stream() - .filter((f) -> "title".equals(f.name())) - .findFirst(); - - Optional chunkNoOverride = columnOverrides.stream() - .filter((f) -> "chunk_no".equals(f.name())) - .findFirst(); - - SchemaColumn wikiSC = wikiOverride.orElse(new SchemaColumn("wiki", DataTypes.TEXT)); - SchemaColumn langSC = langOverride.orElse(new SchemaColumn("language", DataTypes.TEXT)); - SchemaColumn titleSC = titleOverride.orElse(new SchemaColumn("title", DataTypes.TEXT)); - SchemaColumn chunkNoSC = chunkNoOverride.orElse(new SchemaColumn("chunk_no", DataTypes.INT)); - - List partitionKeys = List.of(wikiSC, langSC, titleSC); - List clusteringKeys = List.of(chunkNoSC); - - CassandraVectorStoreConfig.Builder builder = CassandraVectorStoreConfig.builder() - .withCqlSession(context.getBean(CqlSession.class)) - .withKeyspaceName("test_wikidata") - .withTableName("articles") - .withPartitionKeys(partitionKeys) - .withClusteringKeys(clusteringKeys) - .withContentColumnName("body") - .withEmbeddingColumnName("all_minilm_l6_v2_embedding") - .withIndexName("all_minilm_l6_v2_ann") - - .addMetadataColumns(new SchemaColumn("revision", DataTypes.INT), - new SchemaColumn("id", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)) - - // this store uses '§¶' as a deliminator in the document id between db columns - // 'title' and 'chunk_no' - .withPrimaryKeyTranslator((List primaryKeys) -> { - if (primaryKeys.isEmpty()) { - return "test§¶0"; - } - return format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); - }) - .withDocumentIdTranslator((id) -> { - String[] parts = id.split("§¶"); - String title = parts[0]; - int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; - return List.of("simplewiki", "en", title, chunk_no); - }); - - return builder; - } - private void executeCqlFile(ApplicationContext context, String filename) throws IOException { logger.info("executing {}", filename); @@ -599,7 +579,29 @@ private void executeCqlFile(ApplicationContext context, String filename) throws } } + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestApplication { + + @Bean + public EmbeddingModel embeddingModel() { + // default is ONNX all-MiniLM-L6-v2 + return new TransformersEmbeddingModel(); + } + + @Bean + public CqlSession cqlSession() { + return new CqlSessionBuilder() + // comment next two lines out to connect to a local C* cluster + .addContactPoint(cassandraContainer.getContactPoint()) + .withLocalDatacenter(cassandraContainer.getLocalDatacenter()) + .build(); + } + + } + public record StoreWrapper(K store, V conf) { + } } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index 2c757151640..03dd67c27dc 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -29,11 +30,11 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.springframework.ai.CassandraImage; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.ai.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -84,6 +85,26 @@ private static String getText(String uri) { } } + private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { + return CassandraVectorStoreConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); + } + + private static CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) { + CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) + .addMetadataColumns(metadataFields); + + return createTestStore(context, builder); + } + + private static CassandraVectorStore createTestStore(ApplicationContext context, + CassandraVectorStoreConfig.Builder builder) { + CassandraVectorStoreConfig conf = builder.build(); + conf.dropKeyspace(); + return new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)); + } + @Test void ensureBeanGetsCreated() { this.contextRunner.run(context -> { @@ -96,7 +117,7 @@ void ensureBeanGetsCreated() { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("meta1", DataTypes.TEXT), new SchemaColumn("meta2", DataTypes.TEXT))) { @@ -132,7 +153,7 @@ void addAndSearch() { @Test void addAndSearchReturnEmbeddings() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) .returnEmbeddings(); @@ -168,7 +189,7 @@ void addAndSearchReturnEmbeddings() { @Test void searchWithPartitionFilter() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("year", DataTypes.SMALLINT, SchemaColumnTags.INDEXED))) { @@ -224,7 +245,7 @@ void searchWithPartitionFilter() throws InterruptedException { @Test void unsearchableFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -251,7 +272,7 @@ void unsearchableFilters() throws InterruptedException { @Test void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("country", DataTypes.TEXT, SchemaColumnTags.INDEXED), @@ -314,7 +335,7 @@ void searchWithFilters() throws InterruptedException { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", @@ -351,7 +372,7 @@ void documentUpdate() { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { store.add(documents()); @@ -414,24 +435,4 @@ public CqlSession cqlSession() { } - private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { - return CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); - } - - private static CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) { - CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) - .addMetadataColumns(metadataFields); - - return createTestStore(context, builder); - } - - private static CassandraVectorStore createTestStore(ApplicationContext context, - CassandraVectorStoreConfig.Builder builder) { - CassandraVectorStoreConfig conf = builder.build(); - conf.dropKeyspace(); - return new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)); - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java index a8cc74dfbcf..e92349efaa4 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.type.DataTypes; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -40,17 +49,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import com.datastax.oss.driver.api.core.CqlSession; -import com.datastax.oss.driver.api.core.CqlSessionBuilder; -import com.datastax.oss.driver.api.core.type.DataTypes; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -80,16 +80,22 @@ public static String getText(String uri) { } } + private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { + return CassandraVectorStoreConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); + } + @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() @@ -193,10 +199,4 @@ public CqlSession cqlSession() { } - private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { - return CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java index 301b61c494a..7189351da95 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql index c6f6cf17a56..86a8d93fbd5 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql index d2f3fcd622d..42724e314e6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql @@ -1 +1,17 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql index cb1a535824c..5b0064c301d 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql index 5853b2274f4..759374499af 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql index a605116ca3e..673a77e68ac 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql index 68b4583c491..564eb23330f 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml index c057ea33967..51b1fa4d89e 100644 --- a/vector-stores/spring-ai-chroma-store/pom.xml +++ b/vector-stores/spring-ai-chroma-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index ced7e44ff7d..8d4e5b0369c 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chroma; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.springframework.ai.chroma.ChromaApi.QueryRequest.Include; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; @@ -36,10 +40,6 @@ import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - /** * Single-class Chroma API implementation based on the (unofficial) Chroma REST API. * @@ -54,10 +54,10 @@ public class ChromaApi { // Regular expression pattern that looks for a message. private static Pattern MESSAGE_ERROR_PATTERN = Pattern.compile("\"message\":\"(.*?)\""); - private RestClient restClient; - private final ObjectMapper objectMapper; + private RestClient restClient; + private String keyToken; public ChromaApi(String baseUrl) { @@ -99,164 +99,6 @@ public ChromaApi withBasicAuthCredentials(String username, String password) { return this; } - /** - * Chroma embedding collection. - * - * @param id Collection Id. - * @param name The name of the collection. - * @param metadata Metadata associated with the collection. - */ - public record Collection(String id, String name, Map metadata) { - } - - /** - * Request to create a new collection with the given name and metadata. - * - * @param name The name of the collection to create. - * @param metadata Optional metadata to associate with the collection. - */ - public record CreateCollectionRequest(String name, Map metadata) { - public CreateCollectionRequest(String name) { - this(name, new HashMap<>(Map.of("hnsw:space", "cosine"))); - } - } - - /** - * Add embeddings to the chroma data store. - * - * @param ids The ids of the embeddings to add. - * @param embeddings The embeddings to add. - * @param metadata The metadata to associate with the embeddings. When querying, you - * can filter on this metadata. - * @param documents The documents contents to associate with the embeddings. - */ - public record AddEmbeddingsRequest(List ids, List embeddings, - @JsonProperty("metadatas") List> metadata, List documents) { - - // Convenance for adding a single embedding. - public AddEmbeddingsRequest(String id, float[] embedding, Map metadata, String document) { - this(List.of(id), List.of(embedding), List.of(metadata), List.of(document)); - } - } - - /** - * Request to delete embedding from a collection. - * - * @param ids The ids of the embeddings to delete. (Optional) - * @param where Condition to filter items to delete based on metadata values. - * (Optional) - */ - public record DeleteEmbeddingsRequest(List ids, Map where) { - public DeleteEmbeddingsRequest(List ids) { - this(ids, Map.of()); - } - } - - /** - * Get embeddings from a collection. - * - * @param ids IDs of the embeddings to get. - * @param where Condition to filter results based on metadata values. - * @param limit Limit on the number of collection embeddings to get. - * @param offset Offset on the embeddings to get. - * @param include A list of what to include in the results. Can contain "embeddings", - * "metadatas", "documents", "distances". Ids are always included. Defaults to - * [metadatas, documents, distances]. - */ - public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, - List include) { - - public GetEmbeddingsRequest(List ids) { - this(ids, Map.of(), 10, 0, Include.all); - } - - public GetEmbeddingsRequest(List ids, Map where) { - this(ids, where, 10, 0, Include.all); - } - - public GetEmbeddingsRequest(List ids, Map where, int limit, int offset) { - this(ids, where, limit, offset, Include.all); - } - } - - /** - * Object containing the get embedding results. - * - * @param ids List of document ids. One for each returned document. - * @param embeddings List of document embeddings. One for each returned document. - * @param documents List of document contents. One for each returned document. - * @param metadata List of document metadata. One for each returned document. - */ - public record GetEmbeddingResponse(List ids, List embeddings, List documents, - @JsonProperty("metadatas") List> metadata) { - } - - /** - * Request to get the nResults nearest neighbor embeddings for provided - * queryEmbeddings. - * - * @param queryEmbeddings The embeddings to get the closes neighbors of. - * @param nResults The number of neighbors to return for each query_embedding or - * query_texts. - * @param where Condition to filter results based on metadata values. - * @param include A list of what to include in the results. Can contain "embeddings", - * "metadatas", "documents", "distances". Ids are always included. Defaults to - * [metadatas, documents, distances]. - */ - public record QueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, - @JsonProperty("n_results") int nResults, Map where, List include) { - - public enum Include { - - @JsonProperty("metadatas") - METADATAS, - - @JsonProperty("documents") - DOCUMENTS, - - @JsonProperty("distances") - DISTANCES, - - @JsonProperty("embeddings") - EMBEDDINGS; - - public static final List all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS); - - } - - /** - * Convenience to query for a single embedding instead of a batch of embeddings. - */ - public QueryRequest(float[] queryEmbedding, int nResults) { - this(List.of(queryEmbedding), nResults, Map.of(), Include.all); - } - - public QueryRequest(float[] queryEmbedding, int nResults, Map where) { - this(List.of(queryEmbedding), nResults, where, Include.all); - } - } - - /** - * A QueryResponse object containing the query results. - * - * @param ids List of list of document ids. One for each returned document. - * @param embeddings List of list of document embeddings. One for each returned - * document. - * @param documents List of list of document contents. One for each returned document. - * @param metadata List of list of document metadata. One for each returned document. - * @param distances List of list of search distances. One for each returned document. - */ - public record QueryResponse(List> ids, List> embeddings, List> documents, - @JsonProperty("metadatas") List>> metadata, List> distances) { - } - - /** - * Single query embedding response. - */ - public record Embedding(String id, float[] embedding, String document, Map metadata, - Double distances) { - } - public List toEmbeddingResponseList(QueryResponse queryResponse) { List result = new ArrayList<>(); @@ -271,10 +113,6 @@ public List toEmbeddingResponseList(QueryResponse queryResponse) { return result; } - // - // Chroma Client API (https://docs.trychroma.com/js_reference/Client) - // - public Collection createCollection(CreateCollectionRequest createCollectionRequest) { return this.restClient.post() @@ -330,10 +168,6 @@ public Collection getCollection(String collectionName) { } } - private static class CollectionList extends ArrayList { - - } - public List listCollections() { return this.restClient.get() @@ -344,10 +178,6 @@ public List listCollections() { .getBody(); } - // - // Chroma Collection API (https://docs.trychroma.com/js_reference/Collection) - // - public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding) { this.restClient.post() @@ -366,6 +196,7 @@ public List deleteEmbeddings(String collectionId, DeleteEmbeddingsReques .body(deleteRequest) .retrieve() .toEntity(new ParameterizedTypeReference>() { + }) .getBody(); } @@ -391,6 +222,10 @@ public QueryResponse queryCollection(String collectionId, QueryRequest queryRequ .getBody(); } + // + // Chroma Client API (https://docs.trychroma.com/js_reference/Client) + // + public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) { return this.restClient.post() @@ -442,4 +277,181 @@ private String getErrorMessage(HttpStatusCodeException e) { return ""; } + /** + * Chroma embedding collection. + * + * @param id Collection Id. + * @param name The name of the collection. + * @param metadata Metadata associated with the collection. + */ + public record Collection(String id, String name, Map metadata) { + + } + + /** + * Request to create a new collection with the given name and metadata. + * + * @param name The name of the collection to create. + * @param metadata Optional metadata to associate with the collection. + */ + public record CreateCollectionRequest(String name, Map metadata) { + + public CreateCollectionRequest(String name) { + this(name, new HashMap<>(Map.of("hnsw:space", "cosine"))); + } + + } + + // + // Chroma Collection API (https://docs.trychroma.com/js_reference/Collection) + // + + /** + * Add embeddings to the chroma data store. + * + * @param ids The ids of the embeddings to add. + * @param embeddings The embeddings to add. + * @param metadata The metadata to associate with the embeddings. When querying, you + * can filter on this metadata. + * @param documents The documents contents to associate with the embeddings. + */ + public record AddEmbeddingsRequest(List ids, List embeddings, + @JsonProperty("metadatas") List> metadata, List documents) { + + // Convenance for adding a single embedding. + public AddEmbeddingsRequest(String id, float[] embedding, Map metadata, String document) { + this(List.of(id), List.of(embedding), List.of(metadata), List.of(document)); + } + + } + + /** + * Request to delete embedding from a collection. + * + * @param ids The ids of the embeddings to delete. (Optional) + * @param where Condition to filter items to delete based on metadata values. + * (Optional) + */ + public record DeleteEmbeddingsRequest(List ids, Map where) { + + public DeleteEmbeddingsRequest(List ids) { + this(ids, Map.of()); + } + + } + + /** + * Get embeddings from a collection. + * + * @param ids IDs of the embeddings to get. + * @param where Condition to filter results based on metadata values. + * @param limit Limit on the number of collection embeddings to get. + * @param offset Offset on the embeddings to get. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, + List include) { + + public GetEmbeddingsRequest(List ids) { + this(ids, Map.of(), 10, 0, Include.all); + } + + public GetEmbeddingsRequest(List ids, Map where) { + this(ids, where, 10, 0, Include.all); + } + + public GetEmbeddingsRequest(List ids, Map where, int limit, int offset) { + this(ids, where, limit, offset, Include.all); + } + + } + + /** + * Object containing the get embedding results. + * + * @param ids List of document ids. One for each returned document. + * @param embeddings List of document embeddings. One for each returned document. + * @param documents List of document contents. One for each returned document. + * @param metadata List of document metadata. One for each returned document. + */ + public record GetEmbeddingResponse(List ids, List embeddings, List documents, + @JsonProperty("metadatas") List> metadata) { + + } + + /** + * Request to get the nResults nearest neighbor embeddings for provided + * queryEmbeddings. + * + * @param queryEmbeddings The embeddings to get the closes neighbors of. + * @param nResults The number of neighbors to return for each query_embedding or + * query_texts. + * @param where Condition to filter results based on metadata values. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record QueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, + @JsonProperty("n_results") int nResults, Map where, List include) { + + /** + * Convenience to query for a single embedding instead of a batch of embeddings. + */ + public QueryRequest(float[] queryEmbedding, int nResults) { + this(List.of(queryEmbedding), nResults, Map.of(), Include.all); + } + + public QueryRequest(float[] queryEmbedding, int nResults, Map where) { + this(List.of(queryEmbedding), nResults, where, Include.all); + } + + public enum Include { + + @JsonProperty("metadatas") + METADATAS, + + @JsonProperty("documents") + DOCUMENTS, + + @JsonProperty("distances") + DISTANCES, + + @JsonProperty("embeddings") + EMBEDDINGS; + + public static final List all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS); + + } + + } + + /** + * A QueryResponse object containing the query results. + * + * @param ids List of list of document ids. One for each returned document. + * @param embeddings List of list of document embeddings. One for each returned + * document. + * @param documents List of list of document contents. One for each returned document. + * @param metadata List of list of document metadata. One for each returned document. + * @param distances List of list of search distances. One for each returned document. + */ + public record QueryResponse(List> ids, List> embeddings, List> documents, + @JsonProperty("metadatas") List>> metadata, List> distances) { + + } + + /** + * Single query embedding response. + */ + public record Embedding(String id, float[] embedding, String document, Map metadata, + Double distances) { + + } + + private static class CollectionList extends ArrayList { + + } + } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java index cd60e83dc6c..526fc30cf31 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index f7778680652..7afb13c4bda 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -22,6 +22,11 @@ import java.util.Map; import java.util.Optional; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest; @@ -43,11 +48,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.json.JsonMapper; -import io.micrometer.observation.ObservationRegistry; - /** * {@link ChromaVectorStore} is a concrete implementation of the {@link VectorStore} * interface. It is responsible for adding, deleting, and searching documents based on @@ -229,4 +229,4 @@ public Builder createObservationContextBuilder(String operationName) { .withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java index 9cbbaa4a49a..51208d3d9cd 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index c1c934df7de..bfc84340294 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chroma; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chroma; import java.util.List; import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.Collection; @@ -31,9 +34,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -52,57 +54,58 @@ public class ChromaApiIT { @BeforeEach public void beforeEach() { - chroma.listCollections().stream().forEach(c -> chroma.deleteCollection(c.name())); + this.chroma.listCollections().stream().forEach(c -> this.chroma.deleteCollection(c.name())); } @Test public void testClientWithMetadata() { Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5); - var newCollection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); + var newCollection = this.chroma + .createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); } @Test public void testClient() { - var newCollection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); - var getCollection = chroma.getCollection("TestCollection"); + var getCollection = this.chroma.getCollection("TestCollection"); assertThat(getCollection).isNotNull(); assertThat(getCollection.name()).isEqualTo("TestCollection"); assertThat(getCollection.id()).isEqualTo(newCollection.id()); - List collections = chroma.listCollections(); + List collections = this.chroma.listCollections(); assertThat(collections).hasSize(1); assertThat(collections.get(0).id()).isEqualTo(newCollection.id()); - chroma.deleteCollection(newCollection.name()); - assertThat(chroma.listCollections()).hasSize(0); + this.chroma.deleteCollection(newCollection.name()); + assertThat(this.chroma.listCollections()).hasSize(0); } @Test public void testCollection() { - var newCollection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); - assertThat(chroma.countEmbeddings(newCollection.id())).isEqualTo(0); + var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(0); var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", "id2"), List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }), List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)), List.of("Hello World", "Big World")); - chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); + this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f }, Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World"); - chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); + this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); - assertThat(chroma.countEmbeddings(newCollection.id())).isEqualTo(3); + assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(3); - var queryResult = chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + var queryResult = this.chroma.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "key2" : { "$eq": true } } @@ -111,14 +114,14 @@ public void testCollection() { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3"); // Update existing embedding. - chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, + this.chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); assertThat(result.ids().get(0)).isEqualTo("id2"); - queryResult = chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + queryResult = this.chroma.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "key2" : { "$eq": true } } @@ -130,7 +133,7 @@ public void testCollection() { @Test public void testQueryWhere() { - var collection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var collection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f }, Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020), @@ -143,24 +146,24 @@ public void testQueryWhere() { Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023), "The World is Big and Salvation Lurks Around the Corner"); - chroma.upsertEmbeddings(collection.id(), add1); - chroma.upsertEmbeddings(collection.id(), add2); - chroma.upsertEmbeddings(collection.id(), add3); + this.chroma.upsertEmbeddings(collection.id(), add1); + this.chroma.upsertEmbeddings(collection.id(), add2); + this.chroma.upsertEmbeddings(collection.id(), add3); - assertThat(chroma.countEmbeddings(collection.id())).isEqualTo(3); + assertThat(this.chroma.countEmbeddings(collection.id())).isEqualTo(3); - var queryResult = chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); + var queryResult = this.chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3"); - var chromaEmbeddings = chroma.toEmbeddingResponseList(queryResult); + var chromaEmbeddings = this.chroma.toEmbeddingResponseList(queryResult); assertThat(chromaEmbeddings).hasSize(3); assertThat(chromaEmbeddings).hasSize(3); - queryResult = chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + queryResult = this.chroma.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, @@ -171,8 +174,8 @@ public void testQueryWhere() { assertThat(queryResult.ids().get(0)).hasSize(2); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3"); - queryResult = chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + queryResult = this.chroma.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java index b0083010e32..4c748dabb9a 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -32,10 +36,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.MountableFile; + +import static org.assertj.core.api.Assertions.assertThat; /** * ChromaDB with Basic Authentication: @@ -68,7 +70,7 @@ public class BasicAuthChromaWhereIT { @Test public void withInFiltersExpressions1() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index 82cd3bfbd45..ec099430cff 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.util.Collections; import java.util.List; @@ -23,6 +22,10 @@ import java.util.UUID; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -34,9 +37,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -49,6 +51,10 @@ public class ChromaVectorStoreIT { @Container static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -57,29 +63,25 @@ public class ChromaVectorStoreIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class) - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); - @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results2).hasSize(0); @@ -89,7 +91,7 @@ public void addAndSearch() { @Test public void addAndSearchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -129,7 +131,7 @@ public void addAndSearchWithFilters() { public void documentUpdateTest() { // Note ,using OpenAI to calculate embeddings - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -170,11 +172,11 @@ public void documentUpdateTest() { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); @@ -189,14 +191,14 @@ public void searchThresholdTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); }); } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java index 89e9eab2e00..3b39051eecc 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -42,13 +48,8 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -81,13 +82,13 @@ public static String getText(String uri) { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ChromaVectorStore vectorStore = context.getBean(ChromaVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java index 84d7a0bb4bd..19ebd4b9de3 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -32,9 +35,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * ChromaDB with static API Token Authentication: @@ -69,7 +71,7 @@ public class TokenSecuredChromaWhereIT { @Test public void withInFiltersExpressions1() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -93,7 +95,7 @@ public void withInFiltersExpressions1() { @Test public void withInFiltersExpressions() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 6dea5967cd5..89cb9e6e6de 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java index c8d0701bfb2..e7b2c5a01ae 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +package org.springframework.ai.vectorstore; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -27,6 +23,11 @@ import java.util.TimeZone; import java.util.regex.Pattern; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; + /** * ElasticsearchAiSearchFilterExpressionConverter is a class that converts * Filter.Expression objects into Elasticsearch query string representation. It extends diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index bf9932b37b7..32731f754c4 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static java.lang.Math.sqrt; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.util.List; @@ -24,10 +23,22 @@ import java.util.Optional; import java.util.stream.Collectors; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.core.BulkRequest; +import co.elastic.clients.elasticsearch.core.BulkResponse; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; +import co.elastic.clients.elasticsearch.core.search.Hit; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; import co.elastic.clients.transport.Version; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; import org.elasticsearch.client.RestClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -45,18 +56,7 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; - -import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch.core.BulkRequest; -import co.elastic.clients.elasticsearch.core.BulkResponse; -import co.elastic.clients.elasticsearch.core.SearchResponse; -import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; -import co.elastic.clients.elasticsearch.core.search.Hit; -import co.elastic.clients.json.jackson.JacksonJsonpMapper; -import co.elastic.clients.transport.rest_client.RestClientTransport; -import io.micrometer.observation.ObservationRegistry; +import static java.lang.Math.sqrt; /** * The ElasticsearchVectorStore class implements the VectorStore interface and provides @@ -79,6 +79,10 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class); + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm, + VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT); + private final EmbeddingModel embeddingModel; private final ElasticsearchClient elasticsearchClient; @@ -176,14 +180,14 @@ public List doSimilaritySearch(SearchRequest searchRequest) { try { float threshold = (float) searchRequest.getSimilarityThreshold(); // reverting l2_norm distance to its original value - if (options.getSimilarity().equals(SimilarityFunction.l2_norm)) { + if (this.options.getSimilarity().equals(SimilarityFunction.l2_norm)) { threshold = 1 - threshold; } final float finalThreshold = threshold; float[] vectors = this.embeddingModel.embed(searchRequest.getQuery()); - SearchResponse res = elasticsearchClient.search( - sr -> sr.index(options.getIndexName()) + SearchResponse res = this.elasticsearchClient.search( + sr -> sr.index(this.options.getIndexName()) .knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors)) .similarity(finalThreshold) .k((long) searchRequest.getTopK()) @@ -215,7 +219,7 @@ private Document toDocument(Hit hit) { // more info on score/distance calculation // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search private float calculateDistance(Float score) { - switch (options.getSimilarity()) { + switch (this.options.getSimilarity()) { case l2_norm: // the returned value of l2_norm is the opposite of the other functions // (closest to zero means more accurate), so to make it consistent @@ -230,7 +234,7 @@ private float calculateDistance(Float score) { public boolean indexExists() { try { - return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value(); + return this.elasticsearchClient.indices().exists(ex -> ex.index(this.options.getIndexName())).value(); } catch (IOException e) { throw new RuntimeException(e); @@ -240,9 +244,10 @@ public boolean indexExists() { private void createIndexMapping() { try { this.elasticsearchClient.indices() - .create(cr -> cr.index(options.getIndexName()) - .mappings(map -> map.properties("embedding", p -> p.denseVector( - dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions()))))); + .create(cr -> cr.index(this.options.getIndexName()) + .mappings(map -> map.properties("embedding", + p -> p.denseVector(dv -> dv.similarity(this.options.getSimilarity().toString()) + .dims(this.options.getDimensions()))))); } catch (IOException e) { throw new RuntimeException(e); @@ -267,10 +272,6 @@ public Builder createObservationContextBuilder(String operationName) { .withSimilarityMetric(getSimilarityMetric()); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm, - VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.options.getSimilarity())) { return this.options.getSimilarity().name(); diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java index c685122461c..8aee18ac030 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; /** @@ -40,7 +41,7 @@ public class ElasticsearchVectorStoreOptions { private SimilarityFunction similarity = SimilarityFunction.cosine; public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -48,7 +49,7 @@ public void setIndexName(String indexName) { } public int getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(int dims) { @@ -56,7 +57,7 @@ public void setDimensions(int dims) { } public SimilarityFunction getSimilarity() { - return similarity; + return this.similarity; } public void setSimilarity(SimilarityFunction similarity) { diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java index 86fc84c01c0..b28e7313d41 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; /** diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java index 3822096464c..7a6f737ee60 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Date; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import java.util.Date; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; @@ -38,25 +40,25 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void testDate() { - String vectorExpr = converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value(new Date(1704637752148L)))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:2024-01-07T14:29:12Z"); - vectorExpr = converter.convertExpression( + vectorExpr = this.converter.convertExpression( new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value("1970-01-01T00:00:02Z"))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:1970-01-01T00:00:02Z"); } @Test public void testEQ() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country:BG"); } @Test public void tesEqAndGte() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("genre"), new Filter.Value("drama")), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)))); assertThat(vectorExpr).isEqualTo("metadata.genre:drama AND metadata.year:>=2020"); @@ -64,14 +66,14 @@ public void tesEqAndGte() { @Test public void tesIn() { - String vectorExpr = converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), new Filter.Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("(metadata.genre:comedy OR documentary OR drama)"); } @Test public void testNe() { - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")), @@ -81,7 +83,7 @@ public void testNe() { @Test public void testGroup() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Group(new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")))), @@ -92,7 +94,7 @@ public void testGroup() { @Test public void tesBoolean() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("isOpen"), new Filter.Value(true)), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020))), new Filter.Expression(IN, new Filter.Key("country"), new Filter.Value(List.of("BG", "NL", "US"))))); @@ -103,7 +105,7 @@ public void tesBoolean() { @Test public void testDecimal() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(GTE, new Filter.Key("temperature"), new Filter.Value(-15.6)), new Filter.Expression(LTE, new Filter.Key("temperature"), new Filter.Value(20.13)))); @@ -112,11 +114,11 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("\"country 1 2 3\""), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); - vectorExpr = converter + vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("'country 1 2 3'"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java index 2697c19506c..db8b68f3b8e 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index c81972baccc..c262c9a4af8 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -121,7 +122,7 @@ public void addAndDeleteDocumentsTest() { assertThat(stats.total().docs().count()).isEqualTo(0L); - vectorStore.add(documents); + vectorStore.add(this.documents); elasticsearchClient.indices().refresh(); stats = elasticsearchClient.indices() .stats(s -> s.index("spring-ai-document-index")) @@ -148,7 +149,7 @@ public void addAndSearchTest(String similarityFunction) { ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, ElasticsearchVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -160,14 +161,14 @@ public void addAndSearchTest(String similarityFunction) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore @@ -266,7 +267,7 @@ public void searchWithFilters(String similarityFunction) { assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)), hasSize(0)); @@ -334,7 +335,7 @@ public void searchThresholdTest(String similarityFunction) { ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, ElasticsearchVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); SearchRequest query = SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll(); @@ -353,13 +354,13 @@ public void searchThresholdTest(String similarityFunction) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch( diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java index 4a0c9cc58e3..7040b97a600 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,6 +23,15 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.apache.http.HttpHost; import org.awaitility.Awaitility; import org.elasticsearch.client.RestClient; @@ -31,6 +39,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.elasticsearch.ElasticsearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -48,21 +60,11 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.elasticsearch.ElasticsearchContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.greaterThan; -import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; -import co.elastic.clients.json.jackson.JacksonJsonpMapper; -import co.elastic.clients.transport.rest_client.RestClientTransport; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import static org.hamcrest.Matchers.greaterThan;; +; /** * @author Christian Tzolov @@ -92,10 +94,6 @@ public static String getText(String uri) { } } - private ApplicationContextRunner getContextRunner() { - return new ApplicationContextRunner().withUserConfiguration(Config.class); - } - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -103,6 +101,10 @@ public static void beforeAll() { Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); } + private ApplicationContextRunner getContextRunner() { + return new ApplicationContextRunner().withUserConfiguration(Config.class); + } + @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { @@ -124,7 +126,7 @@ void observationVectorStoreAddAndQueryOperations() { TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-gemfire-store/pom.xml b/vector-stores/spring-ai-gemfire-store/pom.xml index e2dcdff04ea..25ab313c88b 100644 --- a/vector-stores/spring-ai-gemfire-store/pom.xml +++ b/vector-stores/spring-ai-gemfire-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 416e550532d..7f3390d6d2d 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -16,17 +16,22 @@ package org.springframework.ai.vectorstore; -import static org.springframework.http.HttpStatus.BAD_REQUEST; -import static org.springframework.http.HttpStatus.NOT_FOUND; - import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.util.annotation.NonNull; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -47,14 +52,8 @@ import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.util.UriComponentsBuilder; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.ObservationRegistry; -import reactor.util.annotation.NonNull; +import static org.springframework.http.HttpStatus.BAD_REQUEST; +import static org.springframework.http.HttpStatus.NOT_FOUND; /** * A VectorStore implementation backed by GemFire. This store supports creating, updating, diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java index 806497e25a4..3d204767a50 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java index 42ecfc338e8..86c71724a66 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static java.util.concurrent.TimeUnit.MINUTES; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -34,6 +31,7 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -43,6 +41,10 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; + /** * @author Geet Rawat * @author Soby Chacko @@ -53,14 +55,22 @@ public class GemFireVectorStoreIT { public static final String INDEX_NAME = "spring-ai-index1"; - private static GemFireCluster gemFireCluster; - private static final int HTTP_SERVICE_PORT = 9090; private static final int LOCATOR_COUNT = 1; private static final int SERVER_COUNT = 1; + private static GemFireCluster gemFireCluster; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @AfterAll public static void stopGemFireCluster() { gemFireCluster.close(); @@ -83,11 +93,6 @@ public static void startGemFireCluster() { String.format("localhost[%d]", gemFireCluster.getLocatorPort())); } - List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -98,15 +103,12 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class); - @Test public void addAndDeleteEmbeddingTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.add(this.documents); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() .atMost(1, MINUTES) .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3)), @@ -116,9 +118,9 @@ public void addAndDeleteEmbeddingTest() { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .atMost(1, MINUTES) @@ -127,7 +129,7 @@ public void addAndSearchTest() { List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(5)); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); @@ -137,7 +139,7 @@ public void addAndSearchTest() { @Test public void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", @@ -175,9 +177,9 @@ public void documentUpdateTest() { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .atMost(1, MINUTES) @@ -198,7 +200,7 @@ public void searchThresholdTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java index 5cdd5059194..abf2374c068 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import com.github.dockerjava.api.model.ExposedPort; +import com.github.dockerjava.api.model.PortBinding; +import com.github.dockerjava.api.model.Ports; +import com.vmware.gemfire.testcontainers.GemFireCluster; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -41,16 +48,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import com.github.dockerjava.api.model.ExposedPort; -import com.github.dockerjava.api.model.PortBinding; -import com.github.dockerjava.api.model.Ports; -import com.vmware.gemfire.testcontainers.GemFireCluster; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; - import static java.util.concurrent.TimeUnit.MINUTES; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; /** @@ -62,14 +61,22 @@ public class GemFireVectorStoreObservationIT { public static final String TEST_INDEX_NAME = "spring-ai-index1"; - private static GemFireCluster gemFireCluster; - private static final int HTTP_SERVICE_PORT = 9090; private static final int LOCATOR_COUNT = 1; private static final int SERVER_COUNT = 1; + private static GemFireCluster gemFireCluster; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @AfterAll public static void stopGemFireCluster() { gemFireCluster.close(); @@ -92,14 +99,6 @@ public static void startGemFireCluster() { String.format("localhost[%d]", gemFireCluster.getLocatorPort())); } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - - List documents = List.of( - new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document(getText("classpath:/test/data/time.shelter.txt")), - new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -113,13 +112,13 @@ public static String getText(String uri) { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-hanadb-store/pom.xml b/vector-stores/spring-ai-hanadb-store/pom.xml index a5fb9e3463a..b794e916490 100644 --- a/vector-stores/spring-ai-hanadb-store/pom.xml +++ b/vector-stores/spring-ai-hanadb-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java index d420a206d76..89fbf5e273c 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java @@ -15,14 +15,18 @@ */ package org.springframework.ai.vectorstore; -import com.fasterxml.jackson.core.JsonProcessingException; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; import io.micrometer.observation.ObservationRegistry; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.model.EmbeddingUtils; @@ -34,11 +38,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - /** * The SAP HANA Cloud vector engine offers multiple use cases in AI scenarios. * diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java index 8e69e66e25b..b8b8faff00d 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; /** @@ -37,11 +38,11 @@ public static HanaCloudVectorStoreConfigBuilder builder() { } public String getTableName() { - return tableName; + return this.tableName; } public int getTopK() { - return topK; + return this.topK; } public static class HanaCloudVectorStoreConfigBuilder { @@ -62,8 +63,8 @@ public HanaCloudVectorStoreConfigBuilder topK(int topK) { public HanaCloudVectorStoreConfig build() { HanaCloudVectorStoreConfig config = new HanaCloudVectorStoreConfig(); - config.tableName = tableName; - config.topK = topK; + config.tableName = this.tableName; + config.topK = this.topK; return config; } diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java index 439b4f88f8e..7b7109c6a19 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import jakarta.persistence.Column; @@ -39,7 +40,7 @@ public HanaVectorEntity() { } public String get_id() { - return _id; + return this._id; } } diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java index 16e33285cf9..e962006f107 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java index ed1f1e393ed..8853ac37fca 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import jakarta.persistence.Column; @@ -31,7 +32,7 @@ public class CricketWorldCup extends HanaVectorEntity { private String content; public String getContent() { - return content; + return this.content; } } diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java index 392e876b4bf..48d97dd2680 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.model.ChatModel; + import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; @@ -33,13 +42,6 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; - /** * @author Rahul Mittal * @since 1.0.0 @@ -74,7 +76,7 @@ public ResponseEntity handleFileUpload(@RequestParam("pdf") MultipartFil Function, List> splitter = new TokenTextSplitter(); List documents = splitter.apply(reader.get()); logger.info("{} documents created from pdf file: {}", documents.size(), pdf.getFilename()); - hanaCloudVectorStore.accept(documents); + this.hanaCloudVectorStore.accept(documents); return ResponseEntity.ok() .body(String.format("%d documents created from pdf file: %s", documents.size(), pdf.getFilename())); } @@ -88,7 +90,7 @@ public Map hanaVectorStoreSearch(@RequestParam(value = "message" var userMessage = new UserMessage(message); Prompt prompt = new Prompt(List.of(similarDocsMessage, userMessage)); - String generation = chatModel.call(prompt).getResult().getOutput().getContent(); + String generation = this.chatModel.call(prompt).getResult().getOutput().getContent(); logger.info("Generation: {}", generation); return Map.of("generation", generation); } diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java index 397ea39be64..2a9aac568f3 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.List; + import jakarta.persistence.EntityManager; import jakarta.persistence.PersistenceContext; import jakarta.transaction.Transactional; -import org.springframework.stereotype.Repository; -import java.util.List; +import org.springframework.stereotype.Repository; /** * @author Rahul Mittal @@ -40,7 +42,7 @@ public void save(String tableName, String id, String embedding, String content) VALUES(:_id, TO_REAL_VECTOR(:embedding), :content) """, tableName); - entityManager.createNativeQuery(sql) + this.entityManager.createNativeQuery(sql) .setParameter("_id", id) .setParameter("embedding", embedding) .setParameter("content", content) @@ -54,7 +56,7 @@ public int deleteEmbeddingsById(String tableName, List idList) { DELETE FROM %s WHERE _ID IN (:ids) """, tableName); - return entityManager.createNativeQuery(sql).setParameter("ids", idList).executeUpdate(); + return this.entityManager.createNativeQuery(sql).setParameter("ids", idList).executeUpdate(); } @Override @@ -64,7 +66,7 @@ public int deleteAllEmbeddings(String tableName) { DELETE FROM %s """, tableName); - return entityManager.createNativeQuery(sql).executeUpdate(); + return this.entityManager.createNativeQuery(sql).executeUpdate(); } @Override @@ -74,7 +76,7 @@ public List cosineSimilaritySearch(String tableName, int topK, ORDER BY COSINE_SIMILARITY(EMBEDDING, TO_REAL_VECTOR(:queryEmbedding)) DESC """, tableName); - return entityManager.createNativeQuery(sql, CricketWorldCup.class) + return this.entityManager.createNativeQuery(sql, CricketWorldCup.class) .setParameter("topK", topK) .setParameter("queryEmbedding", queryEmbedding) .getResultList(); diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreIT.java index 313bd69bec0..c55f52ceca5 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; @@ -59,7 +60,7 @@ public class HanaCloudVectorStoreIT { @Test public void vectorStoreTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(HanaCloudVectorStore.class); int deleteCount = ((HanaCloudVectorStore) vectorStore).purgeEmbeddings(); @@ -128,4 +129,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java index 78c324a4d3e..c61fb2aaf4e 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,8 +23,12 @@ import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -46,9 +49,7 @@ import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean; import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -62,6 +63,9 @@ public class HanaVectorStoreObservationIT { private static final String TEST_TABLE_NAME = "CRICKET_WORLD_CUP"; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -77,19 +81,16 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties b/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties index faf788a4ae3..f2d9b9274ad 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties +++ b/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spring.ai.openai.api-key=${OPENAI_API_KEY} spring.ai.openai.embedding.options.model=text-embedding-ada-002 diff --git a/vector-stores/spring-ai-milvus-store/pom.xml b/vector-stores/spring-ai-milvus-store/pom.xml index 8fe98ecde9c..bdf72f777ab 100644 --- a/vector-stores/spring-ai-milvus-store/pom.xml +++ b/vector-stores/spring-ai-milvus-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java index c3d6d1a2dd7..c7e9a093999 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter.Expression; diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index b1141949a21..fe91142344c 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + import com.alibaba.fastjson.JSONObject; import io.micrometer.observation.ObservationRegistry; import io.milvus.client.MilvusServiceClient; @@ -44,6 +51,7 @@ import io.milvus.response.SearchResultsWrapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -60,12 +68,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - /** * @author Christian Tzolov * @author Soby Chacko @@ -73,8 +75,6 @@ */ public class MilvusVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); - public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536; public static final int INVALID_EMBEDDING_DIMENSION = -1; @@ -97,6 +97,12 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements public static final List SEARCH_OUTPUT_FIELDS = List.of(DOC_ID_FIELD_NAME, CONTENT_FIELD_NAME, METADATA_FIELD_NAME); + private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); + + private static Map SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, + VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, + VectorStoreSimilarityMetric.DOT); + public final FilterExpressionConverter filterExpressionConverter = new MilvusFilterExpressionConverter(); private final MilvusServiceClient milvusClient; @@ -109,151 +115,6 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements private final BatchingStrategy batchingStrategy; - /** - * Configuration for the Milvus vector store. - */ - public static class MilvusVectorStoreConfig { - - private final String databaseName; - - private final String collectionName; - - private final int embeddingDimension; - - private final IndexType indexType; - - private final MetricType metricType; - - private final String indexParameters; - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - /** - * {@return the default config} - */ - public static MilvusVectorStoreConfig defaultConfig() { - return builder().build(); - } - - private MilvusVectorStoreConfig(Builder builder) { - this.databaseName = builder.databaseName; - this.collectionName = builder.collectionName; - this.embeddingDimension = builder.embeddingDimension; - this.indexType = builder.indexType; - this.metricType = builder.metricType; - this.indexParameters = builder.indexParameters; - } - - public static class Builder { - - private String databaseName = DEFAULT_DATABASE_NAME; - - private String collectionName = DEFAULT_COLLECTION_NAME; - - private int embeddingDimension = INVALID_EMBEDDING_DIMENSION; - - private IndexType indexType = IndexType.IVF_FLAT; - - private MetricType metricType = MetricType.COSINE; - - private String indexParameters = "{\"nlist\":1024}"; - - private Builder() { - } - - /** - * Configures the Milvus metric type to use. Leave {@literal null} or blank to - * use the metric metric: https://milvus.io/docs/metric.md#floating - * @param metricType the metric type to use - * @return this builder - */ - public Builder withMetricType(MetricType metricType) { - Assert.notNull(metricType, "Collection Name must not be empty"); - Assert.isTrue( - metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE, - "Only the text metric types IP and L2 are supported"); - - this.metricType = metricType; - return this; - } - - /** - * Configures the Milvus index type to use. Leave {@literal null} or blank to - * use the default index. - * @param indexType the index type to use - * @return this builder - */ - public Builder withIndexType(IndexType indexType) { - this.indexType = indexType; - return this; - } - - /** - * Configures the Milvus index parameters to use. Leave {@literal null} or - * blank to use the default index parameters. - * @param indexParameters the index parameters to use - * @return this builder - */ - public Builder withIndexParameters(String indexParameters) { - this.indexParameters = indexParameters; - return this; - } - - /** - * Configures the Milvus database name to use. Leave {@literal null} or blank - * to use the default database. - * @param databaseName the database name to use - * @return this builder - */ - public Builder withDatabaseName(String databaseName) { - this.databaseName = databaseName; - return this; - } - - /** - * Configures the Milvus collection name to use. Leave {@literal null} or - * blank to use the default collection name. - * @param collectionName the collection name to use - * @return this builder - */ - public Builder withCollectionName(String collectionName) { - this.collectionName = collectionName; - return this; - } - - /** - * Configures the size of the embedding. Defaults to {@literal 1536}, inline - * with OpenAIs embeddings. - * @param newEmbeddingDimension The dimension of the embedding - * @return this builder - */ - public Builder withEmbeddingDimension(int newEmbeddingDimension) { - - Assert.isTrue(newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768, - "Dimension has to be withing the boundaries 1 and 32768 (inclusively)"); - - this.embeddingDimension = newEmbeddingDimension; - return this; - } - - /** - * {@return the immutable configuration} - */ - public MilvusVectorStoreConfig build() { - return new MilvusVectorStoreConfig(this); - } - - } - - } - public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, boolean initializeSchema) { this(milvusClient, embeddingModel, MilvusVectorStoreConfig.defaultConfig(), initializeSchema, @@ -369,7 +230,7 @@ public List doSimilaritySearch(SearchRequest request) { searchParamBuilder.withExpr(nativeFilterExpressions); } - R respSearch = milvusClient.search(searchParamBuilder.build()); + R respSearch = this.milvusClient.search(searchParamBuilder.build()); if (respSearch.getException() != null) { throw new RuntimeException("Search failed!", respSearch.getException()); @@ -558,10 +419,6 @@ public org.springframework.ai.vectorstore.observation.VectorStoreObservationCont .withNamespace(this.config.databaseName); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, - VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, - VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.config.metricType)) { return this.config.metricType.name(); @@ -569,4 +426,149 @@ private String getSimilarityMetric() { return SIMILARITY_TYPE_MAPPING.get(this.config.metricType).value(); } + /** + * Configuration for the Milvus vector store. + */ + public static class MilvusVectorStoreConfig { + + private final String databaseName; + + private final String collectionName; + + private final int embeddingDimension; + + private final IndexType indexType; + + private final MetricType metricType; + + private final String indexParameters; + + private MilvusVectorStoreConfig(Builder builder) { + this.databaseName = builder.databaseName; + this.collectionName = builder.collectionName; + this.embeddingDimension = builder.embeddingDimension; + this.indexType = builder.indexType; + this.metricType = builder.metricType; + this.indexParameters = builder.indexParameters; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + /** + * {@return the default config} + */ + public static MilvusVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String databaseName = DEFAULT_DATABASE_NAME; + + private String collectionName = DEFAULT_COLLECTION_NAME; + + private int embeddingDimension = INVALID_EMBEDDING_DIMENSION; + + private IndexType indexType = IndexType.IVF_FLAT; + + private MetricType metricType = MetricType.COSINE; + + private String indexParameters = "{\"nlist\":1024}"; + + private Builder() { + } + + /** + * Configures the Milvus metric type to use. Leave {@literal null} or blank to + * use the metric metric: https://milvus.io/docs/metric.md#floating + * @param metricType the metric type to use + * @return this builder + */ + public Builder withMetricType(MetricType metricType) { + Assert.notNull(metricType, "Collection Name must not be empty"); + Assert.isTrue( + metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE, + "Only the text metric types IP and L2 are supported"); + + this.metricType = metricType; + return this; + } + + /** + * Configures the Milvus index type to use. Leave {@literal null} or blank to + * use the default index. + * @param indexType the index type to use + * @return this builder + */ + public Builder withIndexType(IndexType indexType) { + this.indexType = indexType; + return this; + } + + /** + * Configures the Milvus index parameters to use. Leave {@literal null} or + * blank to use the default index parameters. + * @param indexParameters the index parameters to use + * @return this builder + */ + public Builder withIndexParameters(String indexParameters) { + this.indexParameters = indexParameters; + return this; + } + + /** + * Configures the Milvus database name to use. Leave {@literal null} or blank + * to use the default database. + * @param databaseName the database name to use + * @return this builder + */ + public Builder withDatabaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + /** + * Configures the Milvus collection name to use. Leave {@literal null} or + * blank to use the default collection name. + * @param collectionName the collection name to use + * @return this builder + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * Configures the size of the embedding. Defaults to {@literal 1536}, inline + * with OpenAIs embeddings. + * @param newEmbeddingDimension The dimension of the embedding + * @return this builder + */ + public Builder withEmbeddingDimension(int newEmbeddingDimension) { + + Assert.isTrue(newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768, + "Dimension has to be withing the boundaries 1 and 32768 (inclusively)"); + + this.embeddingDimension = newEmbeddingDimension; + return this; + } + + /** + * {@return the immutable configuration} + */ + public MilvusVectorStoreConfig build() { + return new MilvusVectorStoreConfig(this); + } + + } + + } + } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java index 78538ee8f7a..bc60af7fa23 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import io.milvus.client.MilvusServiceClient; @@ -57,38 +58,40 @@ public void explicitlySetDimensions() { .withEmbeddingDimension(explicitDimensions) .build(); - var dim = new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()) + var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, config, true, + new TokenCountBatchingStrategy()) .embeddingDimensions(); assertThat(dim).isEqualTo(explicitDimensions); - verify(embeddingModel, never()).dimensions(); + verify(this.embeddingModel, never()).dimensions(); } @Test public void embeddingModelDimensions() { - when(embeddingModel.dimensions()).thenReturn(969); + when(this.embeddingModel.dimensions()).thenReturn(969); MilvusVectorStoreConfig config = MilvusVectorStoreConfig.builder().build(); - var dim = new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()) + var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, config, true, + new TokenCountBatchingStrategy()) .embeddingDimensions(); assertThat(dim).isEqualTo(969); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } @Test public void fallBackToDefaultDimensions() { - when(embeddingModel.dimensions()).thenThrow(new RuntimeException()); + when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); - var dim = new MilvusVectorStore(milvusClient, embeddingModel, MilvusVectorStoreConfig.builder().build(), true, - new TokenCountBatchingStrategy()) + var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, + MilvusVectorStoreConfig.builder().build(), true, new TokenCountBatchingStrategy()) .embeddingDimensions(); assertThat(dim).isEqualTo(MilvusVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } @ParameterizedTest diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java index afab2766a26..dd0b8ef82d7 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; @@ -45,14 +46,14 @@ public class MilvusFilterExpressionConverterTests { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata[\"country\"] == \"BG\""); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("metadata[\"genre\"] == \"drama\" && metadata[\"year\"] >= 2020"); @@ -61,7 +62,7 @@ public void tesEqAndGte() { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("metadata[\"genre\"] in [\"comedy\",\"documentary\",\"drama\"]"); } @@ -69,7 +70,7 @@ public void tesIn() { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -80,7 +81,7 @@ public void testNe() { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -91,7 +92,7 @@ public void testGroup() { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -103,7 +104,7 @@ public void tesBoolean() { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -112,11 +113,11 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata[\"country 1 2 3\"] == \"BG\""); - vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata[\"country 1 2 3\"] == \"BG\""); } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java index ffdcd3c4b0f..8212474bd77 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java index 118b96c671a..98a88e7b7d1 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -31,6 +32,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -45,7 +47,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.milvus.MilvusContainer; import static org.assertj.core.api.Assertions.assertThat; @@ -88,30 +89,31 @@ private void resetCollection(VectorStore vectorStore) { @ValueSource(strings = { "COSINE", "L2", "IP" }) public void addAndSearch(String metricType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); - assertThat(resultDoc.getContent()).contains( - "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); - // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - assertThat(results).hasSize(0); - }); + results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + assertThat(results).hasSize(0); + }); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -121,135 +123,140 @@ public void searchWithFilters(String metricType) throws InterruptedException { // https://milvus.io/docs/json_data_type.md - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", - Map.of("country", "BG", "year", 2020)); - var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", - Map.of("country", "NL")); - var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", - Map.of("country", "BG", "year", 2023)); + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2020)); + var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "NL")); + var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2023)); - vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); + vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); - List results = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)); - assertThat(results).hasSize(3); + List results = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)); + assertThat(results).hasSize(3); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'NL'")); - assertThat(results).hasSize(1); - assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'NL'")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'BG'")); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG'")); - assertThat(results).hasSize(2); - assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); - assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'BG' && year == 2020")); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG' && year == 2020")); - assertThat(results).hasSize(1); - assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("NOT(country == 'BG' && year == 2020)")); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'BG' && year == 2020)")); - assertThat(results).hasSize(2); - assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); - assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); - }); + }); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "L2", "IP" }) public void documentUpdate(String metricType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", - Collections.singletonMap("meta1", "meta1")); + Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", + Collections.singletonMap("meta1", "meta1")); - vectorStore.add(List.of(document)); + vectorStore.add(List.of(document)); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(document.getId()); - assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); + assertThat(resultDoc.getMetadata()).containsKey("meta1"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); - Document sameIdDocument = new Document(document.getId(), - "The World is Big and Salvation Lurks Around the Corner", - Collections.singletonMap("meta2", "meta2")); + Document sameIdDocument = new Document(document.getId(), + "The World is Big and Salvation Lurks Around the Corner", + Collections.singletonMap("meta2", "meta2")); - vectorStore.add(List.of(sameIdDocument)); + vectorStore.add(List.of(sameIdDocument)); - results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); + results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); - assertThat(results).hasSize(1); - resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(document.getId()); - assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(results).hasSize(1); + resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); - vectorStore.delete(List.of(document.getId())); + vectorStore.delete(List.of(document.getId())); - }); + }); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "IP" }) public void searchWithThreshold(String metricType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); - List fullResult = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); + List fullResult = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List distances = fullResult.stream() + .map(doc -> (Float) doc.getMetadata().get("distance")) + .toList(); - assertThat(distances).hasSize(3); + assertThat(distances).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + float threshold = (distances.get(0) + distances.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); - assertThat(resultDoc.getContent()).contains( - "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); - }); + }); } @SpringBootConfiguration @@ -265,7 +272,7 @@ public VectorStore vectorStore(MilvusServiceClient milvusClient, EmbeddingModel .withCollectionName("test_vector_store") .withDatabaseName("default") .withIndexType(IndexType.IVF_FLAT) - .withMetricType(metricType) + .withMetricType(this.metricType) .build(); return new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()); } @@ -288,4 +295,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java index 4abca78758d..3a8a7bc7e96 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import io.milvus.client.MilvusServiceClient; +import io.milvus.param.ConnectParam; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -40,17 +50,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.milvus.MilvusContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import io.milvus.client.MilvusServiceClient; -import io.milvus.param.ConnectParam; -import io.milvus.param.IndexType; -import io.milvus.param.MetricType; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -85,13 +86,13 @@ public static String getText(String uri) { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-mongodb-atlas-store/pom.xml b/vector-stores/spring-ai-mongodb-atlas-store/pom.xml index c9bde10489e..48689d36d8e 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/pom.xml +++ b/vector-stores/spring-ai-mongodb-atlas-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java index be2705c6456..c72f51c4f41 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index aa16e3f29be..0cb9f974bbd 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,15 @@ package org.springframework.ai.vectorstore; -import static org.springframework.data.mongodb.core.query.Criteria.where; - import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import com.mongodb.MongoCommandException; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -42,9 +43,7 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.util.Assert; -import com.mongodb.MongoCommandException; - -import io.micrometer.observation.ObservationRegistry; +import static org.springframework.data.mongodb.core.query.Criteria.where; /** * @author Chris Smith @@ -119,8 +118,8 @@ public void afterPropertiesSet() throws Exception { } // Create the collection if it does not exist - if (!mongoTemplate.collectionExists(this.config.collectionName)) { - mongoTemplate.createCollection(this.config.collectionName); + if (!this.mongoTemplate.collectionExists(this.config.collectionName)) { + this.mongoTemplate.createCollection(this.config.collectionName); } // Create search index createSearchIndex(); @@ -128,7 +127,7 @@ public void afterPropertiesSet() throws Exception { private void createSearchIndex() { try { - mongoTemplate.executeCommand(createSearchIndexDefinition()); + this.mongoTemplate.executeCommand(createSearchIndexDefinition()); } catch (UncategorizedMongoDbException e) { Throwable cause = e.getCause(); @@ -228,6 +227,15 @@ public List doSimilaritySearch(SearchRequest request) { .toList(); } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.MONGODB.value(), operationName) + .withCollectionName(this.config.collectionName) + .withDimensions(this.embeddingModel.dimensions()) + .withFieldName(this.config.pathName); + } + public static class MongoDBVectorStoreConfig { private final String collectionName; @@ -324,13 +332,4 @@ public MongoDBVectorStoreConfig build() { } - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.MONGODB.value(), operationName) - .withCollectionName(this.config.collectionName) - .withDimensions(this.embeddingModel.dimensions()) - .withFieldName(this.config.pathName); - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java index 6888b9e1120..b741193dd09 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; import org.bson.Document; + import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.lang.NonNull; @@ -28,15 +30,16 @@ record VectorSearchAggregation(List embeddings, String path, int numCandi @SuppressWarnings("null") @Override public org.bson.Document toDocument(@NonNull AggregationOperationContext context) { - var vectorSearch = new Document("queryVector", embeddings).append("path", path) - .append("numCandidates", numCandidates) - .append("index", index) - .append("limit", count); - if (!filter.isEmpty()) { - vectorSearch.append("filter", Document.parse(filter)); + var vectorSearch = new Document("queryVector", this.embeddings).append("path", this.path) + .append("numCandidates", this.numCandidates) + .append("index", this.index) + .append("limit", this.count); + if (!this.filter.isEmpty()) { + vectorSearch.append("filter", Document.parse(this.filter)); } var doc = new Document("$vectorSearch", vectorSearch); return context.getMappedObject(doc); } -} \ No newline at end of file + +} diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java index 6ab38c55ca1..a8df6929e94 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; @@ -45,14 +46,14 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"metadata.country\":{$eq:\"BG\"}}"); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr) @@ -62,7 +63,7 @@ public void tesEqAndGte() { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("{\"metadata.genre\":{$in:[\"comedy\",\"documentary\",\"drama\"]}}"); } @@ -70,7 +71,7 @@ public void tesIn() { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -81,7 +82,7 @@ public void testNe() { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -92,7 +93,7 @@ public void testGroup() { @Test public void testBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -104,7 +105,7 @@ public void testBoolean() { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -114,11 +115,11 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"metadata.country 1 2 3\":{$eq:\"BG\"}}"); - vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"metadata.country 1 2 3\":{$eq:\"BG\"}}"); } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java index 45c8a140bdf..5ec855ec0cf 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + import com.mongodb.client.MongoClient; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -34,17 +45,6 @@ import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.util.MimeType; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -66,7 +66,7 @@ class MongoDBAtlasVectorStoreIT { @BeforeEach public void beforeEach() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MongoTemplate mongoTemplate = context.getBean(MongoTemplate.class); mongoTemplate.getCollection("vector_store").deleteMany(new org.bson.Document()); }); @@ -74,7 +74,7 @@ public void beforeEach() { @Test void vectorStoreTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); List documents = List.of( @@ -109,7 +109,7 @@ void vectorStoreTest() { @Test void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", @@ -144,7 +144,7 @@ void documentUpdateTest() { @Test void searchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -228,6 +228,7 @@ public EmbeddingModel embeddingModel() { @Bean public Converter mimeTypeToStringConverter() { return new Converter() { + @Override public String convert(MimeType source) { return source.toString(); @@ -238,6 +239,7 @@ public String convert(MimeType source) { @Bean public Converter stringToMimeTypeConverter() { return new Converter() { + @Override public MimeType convert(String source) { return MimeType.valueOf(source); @@ -253,4 +255,4 @@ public MongoCustomConversions mongoCustomConversions(Converter } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java index ed59c4993aa..946e813e4ce 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java index 3c78d31ea35..b8ddad7b43d 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -23,9 +22,17 @@ import java.util.List; import java.util.Map; +import com.mongodb.client.MongoClient; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -48,15 +55,7 @@ import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.util.MimeType; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import com.mongodb.client.MongoClient; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -93,7 +92,7 @@ public static String getText(String uri) { @BeforeEach public void beforeEach() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MongoTemplate mongoTemplate = context.getBean(MongoTemplate.class); mongoTemplate.getCollection("vector_store").deleteMany(new org.bson.Document()); }); @@ -102,13 +101,13 @@ public void beforeEach() { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Thread.sleep(5000); @@ -212,6 +211,7 @@ public EmbeddingModel embeddingModel() { @Bean public Converter mimeTypeToStringConverter() { return new Converter() { + @Override public String convert(MimeType source) { return source.toString(); @@ -222,6 +222,7 @@ public String convert(MimeType source) { @Bean public Converter stringToMimeTypeConverter() { return new Converter() { + @Override public MimeType convert(String source) { return MimeType.valueOf(source); diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java index d217eb54b0c..3e6cce348bb 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.List; + import org.bson.Document; import org.junit.jupiter.api.Test; -import org.springframework.data.mongodb.core.aggregation.Aggregation; -import java.util.List; +import org.springframework.data.mongodb.core.aggregation.Aggregation; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -60,4 +62,4 @@ void toDocumentWithFilter() { assertEquals(expected, document); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-neo4j-store/pom.xml b/vector-stores/spring-ai-neo4j-store/pom.xml index ae1b5b3aff2..913277180f1 100644 --- a/vector-stores/spring-ai-neo4j-store/pom.xml +++ b/vector-stores/spring-ai-neo4j-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 55c169d242f..6d938a80194 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,10 +22,12 @@ import java.util.Optional; import java.util.function.Predicate; +import io.micrometer.observation.ObservationRegistry; import org.neo4j.cypherdsl.support.schema_name.SchemaNames; import org.neo4j.driver.Driver; import org.neo4j.driver.SessionConfig; import org.neo4j.driver.Values; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -40,8 +42,6 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; - /** * @author Gerrit Meier * @author Michael Simons @@ -51,6 +51,195 @@ */ public class Neo4jVectorStore extends AbstractObservationVectorStore implements InitializingBean { + public static final int DEFAULT_EMBEDDING_DIMENSION = 1536; + + public static final String DEFAULT_LABEL = "Document"; + + public static final String DEFAULT_INDEX_NAME = "spring-ai-document-index"; + + public static final String DEFAULT_EMBEDDING_PROPERTY = "embedding"; + + public static final String DEFAULT_ID_PROPERTY = "id"; + + public static final String DEFAULT_CONSTRAINT_NAME = DEFAULT_LABEL + "_unique_idx"; + + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + Neo4jDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, Neo4jDistanceType.EUCLIDEAN, + VectorStoreSimilarityMetric.EUCLIDEAN); + + private final Neo4jVectorFilterExpressionConverter filterExpressionConverter = new Neo4jVectorFilterExpressionConverter(); + + private final Driver driver; + + private final EmbeddingModel embeddingModel; + + private final Neo4jVectorStoreConfig config; + + private final boolean initializeSchema; + + private final BatchingStrategy batchingStrategy; + + public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreConfig config, + boolean initializeSchema) { + this(driver, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); + } + + public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreConfig config, + boolean initializeSchema, ObservationRegistry observationRegistry, + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { + super(observationRegistry, customObservationConvention); + + this.initializeSchema = initializeSchema; + Assert.notNull(driver, "Neo4j driver must not be null"); + Assert.notNull(embeddingModel, "Embedding model must not be null"); + this.driver = driver; + this.embeddingModel = embeddingModel; + this.config = config; + this.batchingStrategy = batchingStrategy; + } + + @Override + public void doAdd(List documents) { + + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + + var rows = documents.stream().map(this::documentToRecord).toList(); + + try (var session = this.driver.session()) { + var statement = """ + UNWIND $rows AS row + MERGE (u:%s {%2$s: row.id}) + ON CREATE + SET u += row.properties + ON MATCH + SET u = {} + SET u.%2$s = row.id, + u += row.properties + WITH row, u + CALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.embedding) + """.formatted(this.config.label, this.config.idProperty); + session.run(statement, Map.of("rows", rows, "embeddingProperty", this.config.embeddingProperty)).consume(); + } + } + + @Override + public Optional doDelete(List idList) { + + try (var session = this.driver.session(this.config.sessionConfig)) { + + var summary = session + .run(""" + MATCH (n:%s) WHERE n.%s IN $ids + CALL { WITH n DETACH DELETE n } IN TRANSACTIONS OF $transactionSize ROWS + """.formatted(this.config.label, this.config.idProperty), + Map.of("ids", idList, "transactionSize", 10_000)) + .consume(); + return Optional.of(idList.size() == summary.counters().nodesDeleted()); + } + } + + @Override + public List doSimilaritySearch(SearchRequest request) { + Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero"); + Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1, + "The similarity score is bounded between 0 and 1; least to most similar respectively."); + + var embedding = Values.value(this.embeddingModel.embed(request.getQuery())); + try (var session = this.driver.session(this.config.sessionConfig)) { + StringBuilder condition = new StringBuilder("score >= $threshold"); + if (request.hasFilterExpression()) { + condition.append(" AND ") + .append(this.filterExpressionConverter.convertExpression(request.getFilterExpression())); + } + String query = """ + CALL db.index.vector.queryNodes($indexName, $numberOfNearestNeighbours, $embeddingValue) + YIELD node, score + WHERE %s + RETURN node, score""".formatted(condition); + + return session + .run(query, Map.of("indexName", this.config.indexNameNotSanitized, "numberOfNearestNeighbours", + request.getTopK(), "embeddingValue", embedding, "threshold", request.getSimilarityThreshold())) + .list(this::recordToDocument); + } + } + + @Override + public void afterPropertiesSet() { + + if (!this.initializeSchema) { + return; + } + + try (var session = this.driver.session(this.config.sessionConfig)) { + + session + .run("CREATE CONSTRAINT %s IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE" + .formatted(this.config.constraintName, this.config.label, this.config.idProperty)) + .consume(); + + var statement = """ + CREATE VECTOR INDEX %s IF NOT EXISTS FOR (n:%s) ON (n.%s) + OPTIONS {indexConfig: { + `vector.dimensions`: %d, + `vector.similarity_function`: '%s' + }} + """.formatted(this.config.indexName, this.config.label, this.config.embeddingProperty, + this.config.embeddingDimension, this.config.distanceType.name); + session.run(statement).consume(); + session.run("CALL db.awaitIndexes()").consume(); + } + } + + private Map documentToRecord(Document document) { + document.setEmbedding(document.getEmbedding()); + + var row = new HashMap(); + + row.put("id", document.getId()); + + var properties = new HashMap(); + properties.put("text", document.getContent()); + + document.getMetadata().forEach((k, v) -> properties.put("metadata." + k, Values.value(v))); + row.put("properties", properties); + + row.put(this.config.embeddingProperty, Values.value(document.getEmbedding())); + return row; + } + + private Document recordToDocument(org.neo4j.driver.Record neoRecord) { + var node = neoRecord.get("node").asNode(); + var score = neoRecord.get("score").asFloat(); + var metaData = new HashMap(); + metaData.put("distance", 1 - score); + node.keys().forEach(key -> { + if (key.startsWith("metadata.")) { + metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject()); + } + }); + + return new Document(node.get(this.config.idProperty).asString(), node.get("text").asString(), + Map.copyOf(metaData)); + } + + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.NEO4J.value(), operationName) + .withCollectionName(this.config.indexName) + .withDimensions(this.embeddingModel.dimensions()) + .withSimilarityMetric(getSimilarityMetric()); + } + + private String getSimilarityMetric() { + if (!SIMILARITY_TYPE_MAPPING.containsKey(this.config.distanceType)) { + return this.config.distanceType.name(); + } + return SIMILARITY_TYPE_MAPPING.get(this.config.distanceType).value(); + } + /** * An enum to configure the distance function used in the Neo4j vector index. */ @@ -90,6 +279,22 @@ public static final class Neo4jVectorStoreConfig { private final String constraintName; + private Neo4jVectorStoreConfig(Builder builder) { + + this.sessionConfig = Optional.ofNullable(builder.databaseName) + .filter(Predicate.not(String::isBlank)) + .map(SessionConfig::forDatabase) + .orElseGet(SessionConfig::defaultConfig); + this.embeddingDimension = builder.embeddingDimension; + this.distanceType = builder.distanceType; + this.embeddingProperty = SchemaNames.sanitize(builder.embeddingProperty).orElseThrow(); + this.label = SchemaNames.sanitize(builder.label).orElseThrow(); + this.indexNameNotSanitized = builder.indexName; + this.indexName = SchemaNames.sanitize(builder.indexName, true).orElseThrow(); + this.constraintName = SchemaNames.sanitize(builder.constraintName).orElseThrow(); + this.idProperty = SchemaNames.sanitize(builder.idProperty).orElseThrow(); + } + /** * Start building a new configuration. * @return The entry point for creating a new configuration. @@ -107,22 +312,6 @@ public static Neo4jVectorStoreConfig defaultConfig() { return builder().build(); } - private Neo4jVectorStoreConfig(Builder builder) { - - this.sessionConfig = Optional.ofNullable(builder.databaseName) - .filter(Predicate.not(String::isBlank)) - .map(SessionConfig::forDatabase) - .orElseGet(SessionConfig::defaultConfig); - this.embeddingDimension = builder.embeddingDimension; - this.distanceType = builder.distanceType; - this.embeddingProperty = SchemaNames.sanitize(builder.embeddingProperty).orElseThrow(); - this.label = SchemaNames.sanitize(builder.label).orElseThrow(); - this.indexNameNotSanitized = builder.indexName; - this.indexName = SchemaNames.sanitize(builder.indexName, true).orElseThrow(); - this.constraintName = SchemaNames.sanitize(builder.constraintName).orElseThrow(); - this.idProperty = SchemaNames.sanitize(builder.idProperty).orElseThrow(); - } - public static class Builder { private String databaseName; @@ -267,193 +456,4 @@ public Neo4jVectorStoreConfig build() { } - public static final int DEFAULT_EMBEDDING_DIMENSION = 1536; - - public static final String DEFAULT_LABEL = "Document"; - - public static final String DEFAULT_INDEX_NAME = "spring-ai-document-index"; - - public static final String DEFAULT_EMBEDDING_PROPERTY = "embedding"; - - public static final String DEFAULT_ID_PROPERTY = "id"; - - public static final String DEFAULT_CONSTRAINT_NAME = DEFAULT_LABEL + "_unique_idx"; - - private final Neo4jVectorFilterExpressionConverter filterExpressionConverter = new Neo4jVectorFilterExpressionConverter(); - - private final Driver driver; - - private final EmbeddingModel embeddingModel; - - private final Neo4jVectorStoreConfig config; - - private final boolean initializeSchema; - - private final BatchingStrategy batchingStrategy; - - public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreConfig config, - boolean initializeSchema) { - this(driver, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null, - new TokenCountBatchingStrategy()); - } - - public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreConfig config, - boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - super(observationRegistry, customObservationConvention); - - this.initializeSchema = initializeSchema; - Assert.notNull(driver, "Neo4j driver must not be null"); - Assert.notNull(embeddingModel, "Embedding model must not be null"); - this.driver = driver; - this.embeddingModel = embeddingModel; - this.config = config; - this.batchingStrategy = batchingStrategy; - } - - @Override - public void doAdd(List documents) { - - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - - var rows = documents.stream().map(this::documentToRecord).toList(); - - try (var session = this.driver.session()) { - var statement = """ - UNWIND $rows AS row - MERGE (u:%s {%2$s: row.id}) - ON CREATE - SET u += row.properties - ON MATCH - SET u = {} - SET u.%2$s = row.id, - u += row.properties - WITH row, u - CALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.embedding) - """.formatted(this.config.label, this.config.idProperty); - session.run(statement, Map.of("rows", rows, "embeddingProperty", this.config.embeddingProperty)).consume(); - } - } - - @Override - public Optional doDelete(List idList) { - - try (var session = this.driver.session(this.config.sessionConfig)) { - - var summary = session - .run(""" - MATCH (n:%s) WHERE n.%s IN $ids - CALL { WITH n DETACH DELETE n } IN TRANSACTIONS OF $transactionSize ROWS - """.formatted(this.config.label, this.config.idProperty), - Map.of("ids", idList, "transactionSize", 10_000)) - .consume(); - return Optional.of(idList.size() == summary.counters().nodesDeleted()); - } - } - - @Override - public List doSimilaritySearch(SearchRequest request) { - Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero"); - Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1, - "The similarity score is bounded between 0 and 1; least to most similar respectively."); - - var embedding = Values.value(this.embeddingModel.embed(request.getQuery())); - try (var session = this.driver.session(this.config.sessionConfig)) { - StringBuilder condition = new StringBuilder("score >= $threshold"); - if (request.hasFilterExpression()) { - condition.append(" AND ") - .append(this.filterExpressionConverter.convertExpression(request.getFilterExpression())); - } - String query = """ - CALL db.index.vector.queryNodes($indexName, $numberOfNearestNeighbours, $embeddingValue) - YIELD node, score - WHERE %s - RETURN node, score""".formatted(condition); - - return session - .run(query, Map.of("indexName", this.config.indexNameNotSanitized, "numberOfNearestNeighbours", - request.getTopK(), "embeddingValue", embedding, "threshold", request.getSimilarityThreshold())) - .list(this::recordToDocument); - } - } - - @Override - public void afterPropertiesSet() { - - if (!this.initializeSchema) { - return; - } - - try (var session = this.driver.session(this.config.sessionConfig)) { - - session - .run("CREATE CONSTRAINT %s IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE" - .formatted(this.config.constraintName, this.config.label, this.config.idProperty)) - .consume(); - - var statement = """ - CREATE VECTOR INDEX %s IF NOT EXISTS FOR (n:%s) ON (n.%s) - OPTIONS {indexConfig: { - `vector.dimensions`: %d, - `vector.similarity_function`: '%s' - }} - """.formatted(this.config.indexName, this.config.label, this.config.embeddingProperty, - this.config.embeddingDimension, this.config.distanceType.name); - session.run(statement).consume(); - session.run("CALL db.awaitIndexes()").consume(); - } - } - - private Map documentToRecord(Document document) { - document.setEmbedding(document.getEmbedding()); - - var row = new HashMap(); - - row.put("id", document.getId()); - - var properties = new HashMap(); - properties.put("text", document.getContent()); - - document.getMetadata().forEach((k, v) -> properties.put("metadata." + k, Values.value(v))); - row.put("properties", properties); - - row.put(this.config.embeddingProperty, Values.value(document.getEmbedding())); - return row; - } - - private Document recordToDocument(org.neo4j.driver.Record neoRecord) { - var node = neoRecord.get("node").asNode(); - var score = neoRecord.get("score").asFloat(); - var metaData = new HashMap(); - metaData.put("distance", 1 - score); - node.keys().forEach(key -> { - if (key.startsWith("metadata.")) { - metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject()); - } - }); - - return new Document(node.get(this.config.idProperty).asString(), node.get("text").asString(), - Map.copyOf(metaData)); - } - - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.NEO4J.value(), operationName) - .withCollectionName(this.config.indexName) - .withDimensions(this.embeddingModel.dimensions()) - .withSimilarityMetric(getSimilarityMetric()); - } - - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - Neo4jDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, Neo4jDistanceType.EUCLIDEAN, - VectorStoreSimilarityMetric.EUCLIDEAN); - - private String getSimilarityMetric() { - if (!SIMILARITY_TYPE_MAPPING.containsKey(this.config.distanceType)) { - return this.config.distanceType.name(); - } - return SIMILARITY_TYPE_MAPPING.get(this.config.distanceType).value(); - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java index 41ce4f8001d..7321699998d 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java index 3ea60478d38..513fd69433b 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java index 57fcc71797d..3aa43823913 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Collections; @@ -27,15 +28,15 @@ import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; -import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; import org.testcontainers.containers.Neo4jContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -57,6 +58,9 @@ class Neo4jVectorStoreIT { @Container static Neo4jContainer neo4jContainer = new Neo4jContainer<>(Neo4jImage.DEFAULT_IMAGE).withRandomPassword(); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -65,9 +69,6 @@ class Neo4jVectorStoreIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class); - @BeforeEach void cleanDatabase() { this.contextRunner diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java index ebf09454d27..1c841b1c1da 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -29,6 +31,10 @@ import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -46,13 +52,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.Neo4jContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -92,13 +93,13 @@ void cleanDatabase() { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java index 8c8a397f852..4eeaa22f198 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -46,14 +47,14 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testEQ() { // country = "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("node.`metadata.country` = \"BG\""); } @Test public void tesEqAndGte() { // genre = "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("node.`metadata.genre` = \"drama\" AND node.`metadata.year` >= 2020"); @@ -62,7 +63,7 @@ public void tesEqAndGte() { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("node.`metadata.genre` IN [\"comedy\",\"documentary\",\"drama\"]"); } @@ -70,7 +71,7 @@ public void tesIn() { @Test public void tesNIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(NIN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("NOT node.`metadata.genre` IN [\"comedy\",\"documentary\",\"drama\"]"); } @@ -78,7 +79,7 @@ public void tesNIn() { @Test public void testNe() { // year >= 2020 OR country = "BG" AND city <> "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -89,7 +90,7 @@ public void testNe() { @Test public void testGroup() { // (year >= 2020 OR country = "BG") AND NOT city IN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NOT, new Expression(IN, new Key("city"), new Value(List.of("Sofia", "Plovdiv")))))); @@ -100,7 +101,7 @@ public void testGroup() { @Test public void testBoolean() { // isOpen = true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -112,7 +113,7 @@ public void testBoolean() { @Test public void testDecimal() { // temperature >= -15.6 AND temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -122,7 +123,7 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("node.`metadata.country 1 2 3` = \"BG\""); } @@ -131,7 +132,7 @@ public void testComplexIdentifiers() { public void testComplexIdentifiers2() { Filter.Expression expr = new FilterExpressionTextParser() .parse("author in ['john', 'jill'] && 'article_type' == 'blog'"); - String vectorExpr = converter.convertExpression(expr); + String vectorExpr = this.converter.convertExpression(expr); assertThat(vectorExpr) .isEqualTo("node.`metadata.author` IN [\"john\",\"jill\"] AND node.`metadata.'article_type'` = \"blog\""); } diff --git a/vector-stores/spring-ai-opensearch-store/pom.xml b/vector-stores/spring-ai-opensearch-store/pom.xml index 33deb0fdc59..aa3ddfbbd09 100644 --- a/vector-stores/spring-ai-opensearch-store/pom.xml +++ b/vector-stores/spring-ai-opensearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java index 9035a86d299..98876f5e645 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +package org.springframework.ai.vectorstore; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -27,6 +23,11 @@ import java.util.TimeZone; import java.util.regex.Pattern; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; + /** * @author Jemin Huh * @since 1.0.0 diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 1756ded2ff2..f81b108ccbd 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,14 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.io.StringReader; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import io.micrometer.observation.ObservationRegistry; import org.opensearch.client.json.JsonData; import org.opensearch.client.json.JsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; @@ -30,6 +38,7 @@ import org.opensearch.client.transport.endpoints.BooleanResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -46,15 +55,6 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; - -import java.io.IOException; -import java.io.StringReader; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - /** * @author Jemin Huh * @author Soby Chacko @@ -67,8 +67,6 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem public static final String COSINE_SIMILARITY_FUNCTION = "cosinesimil"; - private static final Logger logger = LoggerFactory.getLogger(OpenSearchVectorStore.class); - public static final String DEFAULT_INDEX_NAME = "spring-ai-document-index"; public static final String DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536 = """ @@ -82,6 +80,8 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem } """; + private static final Logger logger = LoggerFactory.getLogger(OpenSearchVectorStore.class); + private final EmbeddingModel embeddingModel; private final OpenSearchClient openSearchClient; @@ -92,12 +92,12 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem private final String mappingJson; - private String similarityFunction; - private final boolean initializeSchema; private final BatchingStrategy batchingStrategy; + private String similarityFunction; + public OpenSearchVectorStore(OpenSearchClient openSearchClient, EmbeddingModel embeddingModel, boolean initializeSchema) { this(openSearchClient, embeddingModel, DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536, @@ -245,7 +245,7 @@ public boolean exists(String targetIndex) { } private CreateIndexResponse createIndexMapping(String index, String mappingJson) { - JsonpMapper jsonpMapper = openSearchClient._transport().jsonpMapper(); + JsonpMapper jsonpMapper = this.openSearchClient._transport().jsonpMapper(); try { return this.openSearchClient.indices() .create(new CreateIndexRequest.Builder().index(index) @@ -285,4 +285,4 @@ else if ("l2".equalsIgnoreCase(this.similarityFunction)) { return this.similarityFunction; } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java index e830fa54520..77e2a95a0bb 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Date; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; @@ -25,38 +34,31 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import java.util.Date; -import java.util.List; - -import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; - class OpenSearchAiSearchFilterExpressionConverterTest { final FilterExpressionConverter converter = new OpenSearchAiSearchFilterExpressionConverter(); @Test public void testDate() { - String vectorExpr = converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value(new Date(1704637752148L)))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:2024-01-07T14:29:12Z"); - vectorExpr = converter.convertExpression( + vectorExpr = this.converter.convertExpression( new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value("1970-01-01T00:00:02Z"))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:1970-01-01T00:00:02Z"); } @Test public void testEQ() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country:BG"); } @Test public void tesEqAndGte() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("genre"), new Filter.Value("drama")), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)))); assertThat(vectorExpr).isEqualTo("metadata.genre:drama AND metadata.year:>=2020"); @@ -64,14 +66,14 @@ public void tesEqAndGte() { @Test public void tesIn() { - String vectorExpr = converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), new Filter.Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("(metadata.genre:comedy OR documentary OR drama)"); } @Test public void testNe() { - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")), @@ -81,7 +83,7 @@ public void testNe() { @Test public void testGroup() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Group(new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")))), @@ -92,7 +94,7 @@ public void testGroup() { @Test public void tesBoolean() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("isOpen"), new Filter.Value(true)), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020))), new Filter.Expression(IN, new Filter.Key("country"), new Filter.Value(List.of("BG", "NL", "US"))))); @@ -103,7 +105,7 @@ public void tesBoolean() { @Test public void testDecimal() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(GTE, new Filter.Key("temperature"), new Filter.Value(-15.6)), new Filter.Expression(LTE, new Filter.Key("temperature"), new Filter.Value(20.13)))); @@ -112,11 +114,11 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("\"country 1 2 3\""), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); - vectorExpr = converter + vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("'country 1 2 3'"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java index dea664624a5..294ed87d5c1 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java index 645bebd4aa9..a38207fb43c 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,17 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.ZonedDateTime; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + import org.apache.hc.core5.http.HttpHost; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; @@ -27,6 +38,9 @@ import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -38,19 +52,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.ZonedDateTime; -import java.util.Date; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.TimeUnit; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -126,7 +127,7 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.withSimilarityFunction(similarityFunction); } - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -138,14 +139,14 @@ public void addAndSearchTest(String similarityFunction) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore @@ -245,7 +246,7 @@ public void searchWithFilters(String similarityFunction) { assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)), hasSize(0)); @@ -318,7 +319,7 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.withSimilarityFunction(similarityFunction); } - vectorStore.add(documents); + vectorStore.add(this.documents); SearchRequest query = SearchRequest.query("Great Depression") .withTopK(50) @@ -339,13 +340,13 @@ public void searchThresholdTest(String similarityFunction) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java index 45298605d50..7ce5101a110 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.net.URISyntaxException; @@ -25,6 +24,9 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.apache.hc.core5.http.HttpHost; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; @@ -34,6 +36,9 @@ import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -51,13 +56,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; /** @@ -87,10 +87,6 @@ public static String getText(String uri) { } } - private ApplicationContextRunner getContextRunner() { - return new ApplicationContextRunner().withUserConfiguration(Config.class); - } - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -98,6 +94,10 @@ public static void beforeAll() { Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); } + private ApplicationContextRunner getContextRunner() { + return new ApplicationContextRunner().withUserConfiguration(Config.class); + } + @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { @@ -115,7 +115,7 @@ void observationVectorStoreAddAndQueryOperations() { TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() @@ -182,7 +182,7 @@ void observationVectorStoreAddAndQueryOperations() { observationRegistry.clear(); - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore diff --git a/vector-stores/spring-ai-oracle-store/pom.xml b/vector-stores/spring-ai-oracle-store/pom.xml index b95d5ee0129..7335b4d8833 100644 --- a/vector-stores/spring-ai-oracle-store/pom.xml +++ b/vector-stores/spring-ai-oracle-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index 290f570f946..a32a904eb87 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,6 @@ package org.springframework.ai.vectorstore; -import static org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT; -import static org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue; - import java.io.ByteArrayOutputStream; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -31,8 +28,16 @@ import java.util.Map; import java.util.Optional; +import io.micrometer.observation.ObservationRegistry; +import oracle.jdbc.OracleType; +import oracle.sql.VECTOR; +import oracle.sql.json.OracleJsonFactory; +import oracle.sql.json.OracleJsonGenerator; +import oracle.sql.json.OracleJsonObject; +import oracle.sql.json.OracleJsonValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -51,13 +56,8 @@ import org.springframework.jdbc.core.RowMapper; import org.springframework.util.StringUtils; -import io.micrometer.observation.ObservationRegistry; -import oracle.jdbc.OracleType; -import oracle.sql.VECTOR; -import oracle.sql.json.OracleJsonFactory; -import oracle.sql.json.OracleJsonGenerator; -import oracle.sql.json.OracleJsonObject; -import oracle.sql.json.OracleJsonValue; +import static org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT; +import static org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue; /** *

    @@ -86,94 +86,8 @@ */ public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(OracleVectorStore.class); - public static final double SIMILARITY_THRESHOLD_EXACT_MATCH = 1.0d; - public enum OracleVectorStoreIndexType { - - /** - * Performs exact nearest neighbor search. - */ - NONE, - - /** - *

    - * The default type of index created for an In-Memory Neighbor Graph vector index - * is Hierarchical Navigable Small World (HNSW). - *

    - * - *

    - * With Navigable Small World (NSW), the idea is to build a proximity graph where - * each vector in the graph connects to several others based on three - * characteristics: - *

      - *
    • The distance between vectors
    • - *
    • The maximum number of closest vector candidates considered at each step of - * the search during insertion (EFCONSTRUCTION)
    • - *
    • Within the maximum number of connections (NEIGHBORS) permitted per - * vector
    • - *
    - * - * @see Oracle - * Database documentation - */ - HNSW, - - /** - *

    - * The default type of index created for a Neighbor Partition vector index is - * Inverted File Flat (IVF) vector index. The IVF index is a technique designed to - * enhance search efficiency by narrowing the search area through the use of - * neighbor partitions or clusters. - *

    - * - * * @see Oracle - * Database documentation - */ - IVF; - - } - - public enum OracleVectorStoreDistanceType { - - /** - * Default metric. It calculates the cosine distance between two vectors. - */ - COSINE, - - /** - * Also called the inner product, calculates the negated dot product of two - * vectors. - */ - DOT, - - /** - * Also called L2_DISTANCE, calculates the Euclidean distance between two vectors. - */ - EUCLIDEAN, - - /** - * Also called L2_SQUARED is the Euclidean distance without taking the square - * root. - */ - EUCLIDEAN_SQUARED, - - /* - * Calculates the hamming distance between two vectors. Requires INT8 element - * type. - */ - // TODO: add HAMMING support, - - /** - * Also called L1_DISTANCE or taxicab distance, calculates the Manhattan distance. - */ - MANHATTAN - - } - public static final String DEFAULT_TABLE_NAME = "SPRING_AI_VECTORS"; public static final OracleVectorStoreIndexType DEFAULT_INDEX_TYPE = OracleVectorStoreIndexType.IVF; @@ -184,6 +98,15 @@ public enum OracleVectorStoreDistanceType { public static final int DEFAULT_SEARCH_ACCURACY = -1; + private static final Logger logger = LoggerFactory.getLogger(OracleVectorStore.class); + + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + OracleVectorStoreDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, + OracleVectorStoreDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, + OracleVectorStoreDistanceType.DOT, VectorStoreSimilarityMetric.DOT); + + public final FilterExpressionConverter filterExpressionConverter = new SqlJsonPathFilterExpressionConverter(); + private final JdbcTemplate jdbcTemplate; private final EmbeddingModel embeddingModel; @@ -192,8 +115,6 @@ public enum OracleVectorStoreDistanceType { private final boolean removeExistingVectorStoreTable; - public final FilterExpressionConverter filterExpressionConverter = new SqlJsonPathFilterExpressionConverter(); - /** * Table name where vectors will be stored. */ @@ -222,6 +143,10 @@ public enum OracleVectorStoreDistanceType { private final BatchingStrategy batchingStrategy; + private final OracleJsonFactory osonFactory = new OracleJsonFactory(); + + private final ByteArrayOutputStream out = new ByteArrayOutputStream(); + public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { this(jdbcTemplate, embeddingModel, DEFAULT_TABLE_NAME, DEFAULT_INDEX_TYPE, DEFAULT_DISTANCE_TYPE, DEFAULT_DIMENSIONS, DEFAULT_SEARCH_ACCURACY, false, false, false); @@ -284,6 +209,7 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode public void doAdd(final List documents) { this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); this.jdbcTemplate.batchUpdate(getIngestStatement(), new BatchPreparedStatementSetter() { + @Override public void setValues(PreparedStatement ps, int i) throws SQLException { final Document document = documents.get(i); @@ -310,21 +236,17 @@ private String getIngestStatement() { merge into %s target using (values(?, ?, ?, ?)) source (id, content, metadata, embedding) on (target.id = source.id) when matched then update set target.content = source.content, target.metadata = source.metadata, target.embedding = source.embedding when not matched then insert (target.id, target.content, target.metadata, target.embedding) values (source.id, source.content, source.metadata, source.embedding)""", - tableName); + this.tableName); } - private final OracleJsonFactory osonFactory = new OracleJsonFactory(); - - private final ByteArrayOutputStream out = new ByteArrayOutputStream(); - /** * Bind binary JSON from the client. * @param m map of metadata * @return the binary JSON ready to be inserted */ private byte[] toJson(final Map m) { - out.reset(); - try (OracleJsonGenerator gen = osonFactory.createJsonBinaryGenerator(out)) { + this.out.reset(); + try (OracleJsonGenerator gen = this.osonFactory.createJsonBinaryGenerator(this.out)) { gen.writeStartObject(); for (String key : m.keySet()) { final Object o = m.get(key); @@ -347,7 +269,7 @@ else if (o instanceof Boolean) { gen.writeEnd(); } - return out.toByteArray(); + return this.out.toByteArray(); } /** @@ -364,7 +286,7 @@ private VECTOR toVECTOR(final float[] floatList) throws SQLException { doubles[i++] = d; } - if (forcedNormalization) { + if (this.forcedNormalization) { return VECTOR.ofFloat64Values(normalize(doubles)); } @@ -398,7 +320,7 @@ private double[] normalize(final double[] v) { @Override public Optional doDelete(final List idList) { - final String sql = String.format("delete from %s where id=?", tableName); + final String sql = String.format("delete from %s where id=?", this.tableName); final int[] argTypes = { Types.VARCHAR }; final List batchArgs = new ArrayList<>(); @@ -406,7 +328,7 @@ public Optional doDelete(final List idList) { batchArgs.add(new Object[] { id }); } - final int[] deleteCounts = jdbcTemplate.batchUpdate(sql, batchArgs, argTypes); + final int[] deleteCounts = this.jdbcTemplate.batchUpdate(sql, batchArgs, argTypes); int deleteCount = 0; for (int detailedResult : deleteCounts) { @@ -423,51 +345,16 @@ public Optional doDelete(final List idList) { return Optional.of(deleteCount == idList.size()); } - private static class DocumentRowMapper implements RowMapper { - - @Override - public Document mapRow(ResultSet rs, int rowNum) throws SQLException { - final Map metadata = getMap(rs.getObject(3, OracleJsonValue.class)); - metadata.put("distance", rs.getDouble(5)); - - final Document document = new Document(rs.getString(1), rs.getString(2), metadata); - final float[] embedding = rs.getObject(4, float[].class); - document.setEmbedding(embedding); - return document; - } - - private Map getMap(OracleJsonValue value) { - final Map result = new HashMap<>(); - - if (value != null) { - final OracleJsonObject json = value.asJsonObject(); - for (String key : json.keySet()) { - result.put(key, json.get(key)); - } - } - - return result; - } - - private List toFloatList(final float[] embeddings) { - final List result = new ArrayList<>(embeddings.length); - for (float v : embeddings) { - result.add(v); - } - return result; - } - - } - @Override public List doSimilaritySearch(SearchRequest request) { try { // From the provided query, generate a vector using the embedding model - final VECTOR embeddingVector = toVECTOR(embeddingModel.embed(request.getQuery())); + final VECTOR embeddingVector = toVECTOR(this.embeddingModel.embed(request.getQuery())); if (logger.isDebugEnabled()) { this.jdbcTemplate.batchUpdate("insert into debug(embedding) values(?)", new BatchPreparedStatementSetter() { + @Override public void setValues(PreparedStatement ps, int i) throws SQLException { setParameterValue(ps, 1, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); @@ -490,20 +377,21 @@ public int getBatchSize() { jsonPathFilter = String.format("where JSON_EXISTS( metadata, '%s' )\n", nativeFilterExpression); } - final String sql = searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" + final String sql = this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance from %s %sorder by distance - fetch first %d rows only""", distanceType == DOT ? "(1+" : "", distanceType.name(), - distanceType == DOT ? ")/2" : "", tableName, jsonPathFilter, request.getTopK()) + fetch first %d rows only""", this.distanceType == DOT ? "(1+" : "", this.distanceType.name(), + this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, request.getTopK()) : String.format( """ select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance from %s %sorder by distance fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", - distanceType == DOT ? "(1+" : "", distanceType.name(), distanceType == DOT ? ")/2" : "", - tableName, jsonPathFilter, request.getTopK(), searchAccuracy); + this.distanceType == DOT ? "(1+" : "", this.distanceType.name(), + this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, + request.getTopK(), this.searchAccuracy); logger.debug("SQL query: " + sql); @@ -518,52 +406,60 @@ else if (request.getSimilarityThreshold() == SIMILARITY_THRESHOLD_EXACT_MATCH) { select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance from %s %sorder by distance - fetch EXACT first %d rows only""", distanceType == DOT ? "(1+" : "", distanceType.name(), - distanceType == DOT ? ")/2" : "", tableName, jsonPathFilter, request.getTopK()); + fetch EXACT first %d rows only""", this.distanceType == DOT ? "(1+" : "", + this.distanceType.name(), this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, + request.getTopK()); logger.debug("SQL query: " + sql); return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector); } else { - if (!forcedNormalization - || (distanceType != OracleVectorStoreDistanceType.COSINE && distanceType != DOT)) { + if (!this.forcedNormalization + || (this.distanceType != OracleVectorStoreDistanceType.COSINE && this.distanceType != DOT)) { throw new RuntimeException( "Similarity threshold filtering requires all vectors to be normalized, see the forcedNormalization parameter for this Vector store. Also only COSINE and DOT distance types are supported."); } - final double distance = distanceType == DOT ? (1d - request.getSimilarityThreshold()) * 2d - 1d + final double distance = this.distanceType == DOT ? (1d - request.getSimilarityThreshold()) * 2d - 1d : 1d - request.getSimilarityThreshold(); if (StringUtils.hasText(nativeFilterExpression)) { jsonPathFilter = String.format(" and JSON_EXISTS( metadata, '%s' )", nativeFilterExpression); } - final String sql = distanceType == DOT ? (searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" - select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance - from %s - where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s - order by distance - fetch first %d rows only""", tableName, jsonPathFilter, request.getTopK()) : String.format(""" - select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance - from %s - where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s - order by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", tableName, jsonPathFilter, - request.getTopK(), searchAccuracy) + final String sql = this.distanceType == DOT ? (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY + ? String.format( + """ + select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance + from %s + where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s + order by distance + fetch first %d rows only""", + this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance + from %s + where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s + order by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy) - ) : (searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" + ) : (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance from %s where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s order by distance - fetch first %d rows only""", tableName, jsonPathFilter, request.getTopK()) : String.format(""" - select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance - from %s - where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s - order by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", tableName, jsonPathFilter, - request.getTopK(), searchAccuracy)); + fetch first %d rows only""", this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance + from %s + where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s + order by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy)); logger.debug("SQL query: " + sql); @@ -581,7 +477,7 @@ public void afterPropertiesSet() throws Exception { if (this.initializeSchema) { // Remove existing VectorStoreTable if (this.removeExistingVectorStoreTable) { - this.jdbcTemplate.execute(String.format("drop table if exists %s purge", tableName)); + this.jdbcTemplate.execute(String.format("drop table if exists %s purge", this.tableName)); } this.jdbcTemplate.execute(String.format(""" @@ -590,27 +486,28 @@ id varchar2(36) default sys_guid() primary key, content clob not null, metadata json not null, embedding vector(%s,FLOAT64) annotations(Distance '%s', IndexType '%s') - )""", tableName, dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(dimensions), - distanceType.name(), indexType.name())); + )""", this.tableName, this.dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(this.dimensions), + this.distanceType.name(), this.indexType.name())); if (logger.isDebugEnabled()) { this.jdbcTemplate.execute(String.format(""" create table if not exists debug ( id varchar2(36) default sys_guid() primary key, embedding vector(%s,FLOAT64) annotations(Distance '%s') - )""", dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(dimensions), - distanceType.name())); + )""", this.dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(this.dimensions), + this.distanceType.name())); } - switch (indexType) { + switch (this.indexType) { case IVF: this.jdbcTemplate.execute(String.format(""" create vector index if not exists vector_index_%s on %s (embedding) organization neighbor partitions distance %s with target accuracy %d - parameters (type IVF, neighbor partitions 10)""", tableName, tableName, - distanceType.name(), searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : searchAccuracy)); + parameters (type IVF, neighbor partitions 10)""", this.tableName, + this.tableName, this.distanceType.name(), + this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : this.searchAccuracy)); break; /* @@ -627,7 +524,7 @@ embedding vector(%s,FLOAT64) annotations(Distance '%s') } public String getTableName() { - return tableName; + return this.tableName; } @Override @@ -638,11 +535,6 @@ public Builder createObservationContextBuilder(String operationName) { .withSimilarityMetric(getSimilarityMetric()); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - OracleVectorStoreDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, - OracleVectorStoreDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, - OracleVectorStoreDistanceType.DOT, VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.distanceType)) { return this.distanceType.name(); @@ -650,4 +542,124 @@ private String getSimilarityMetric() { return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); } + public enum OracleVectorStoreIndexType { + + /** + * Performs exact nearest neighbor search. + */ + NONE, + + /** + *

    + * The default type of index created for an In-Memory Neighbor Graph vector index + * is Hierarchical Navigable Small World (HNSW). + *

    + * + *

    + * With Navigable Small World (NSW), the idea is to build a proximity graph where + * each vector in the graph connects to several others based on three + * characteristics: + *

      + *
    • The distance between vectors
    • + *
    • The maximum number of closest vector candidates considered at each step of + * the search during insertion (EFCONSTRUCTION)
    • + *
    • Within the maximum number of connections (NEIGHBORS) permitted per + * vector
    • + *
    + * + * @see Oracle + * Database documentation + */ + HNSW, + + /** + *

    + * The default type of index created for a Neighbor Partition vector index is + * Inverted File Flat (IVF) vector index. The IVF index is a technique designed to + * enhance search efficiency by narrowing the search area through the use of + * neighbor partitions or clusters. + *

    + * + * * @see Oracle + * Database documentation + */ + IVF; + + } + + public enum OracleVectorStoreDistanceType { + + /** + * Default metric. It calculates the cosine distance between two vectors. + */ + COSINE, + + /** + * Also called the inner product, calculates the negated dot product of two + * vectors. + */ + DOT, + + /** + * Also called L2_DISTANCE, calculates the Euclidean distance between two vectors. + */ + EUCLIDEAN, + + /** + * Also called L2_SQUARED is the Euclidean distance without taking the square + * root. + */ + EUCLIDEAN_SQUARED, + + /* + * Calculates the hamming distance between two vectors. Requires INT8 element + * type. + */ + // TODO: add HAMMING support, + + /** + * Also called L1_DISTANCE or taxicab distance, calculates the Manhattan distance. + */ + MANHATTAN + + } + + private static class DocumentRowMapper implements RowMapper { + + @Override + public Document mapRow(ResultSet rs, int rowNum) throws SQLException { + final Map metadata = getMap(rs.getObject(3, OracleJsonValue.class)); + metadata.put("distance", rs.getDouble(5)); + + final Document document = new Document(rs.getString(1), rs.getString(2), metadata); + final float[] embedding = rs.getObject(4, float[].class); + document.setEmbedding(embedding); + return document; + } + + private Map getMap(OracleJsonValue value) { + final Map result = new HashMap<>(); + + if (value != null) { + final OracleJsonObject json = value.asJsonObject(); + for (String key : json.keySet()) { + result.put(key, json.get(key)); + } + } + + return result; + } + + private List toFloatList(final float[] embeddings) { + final List result = new ArrayList<>(embeddings.length); + for (float v : embeddings) { + result.add(v); + } + return result; + } + + } + } diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java index ad3432b7cf0..0fd446257e3 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java index 3f2250a3c68..6955a660680 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java index b78790ecd5e..638db84b57f 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java @@ -1,10 +1,41 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import javax.sql.DataSource; + import oracle.jdbc.pool.OracleDataSource; import org.junit.Assert; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -22,19 +53,6 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.util.CollectionUtils; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.oracle.OracleContainer; -import org.testcontainers.utility.MountableFile; - -import javax.sql.DataSource; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY; @@ -51,15 +69,6 @@ public class OracleVectorStoreIT { new Document(getText("classpath:/test/data/time.shelter.txt")), new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - public static String getText(final String uri) { - try { - return new DefaultResourceLoader().getResource(uri).getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestClient.class) .withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=COSINE", @@ -70,80 +79,63 @@ public static String getText(final String uri) { String.format("app.datasource.password=%s", oracle23aiContainer.getPassword()), "app.datasource.type=oracle.jdbc.pool.OracleDataSource"); - @SpringBootConfiguration - @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) - public static class TestClient { - - @Value("${test.spring.ai.vectorstore.oracle.distanceType}") - OracleVectorStore.OracleVectorStoreDistanceType distanceType; - - @Value("${test.spring.ai.vectorstore.oracle.searchAccuracy}") - int searchAccuracy; - - @Bean - public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - return new OracleVectorStore(jdbcTemplate, embeddingModel, OracleVectorStore.DEFAULT_TABLE_NAME, - OracleVectorStore.OracleVectorStoreIndexType.IVF, distanceType, 384, searchAccuracy, true, true, - true); + public static String getText(final String uri) { + try { + return new DefaultResourceLoader().getResource(uri).getContentAsString(StandardCharsets.UTF_8); } - - @Bean - public JdbcTemplate myJdbcTemplate(DataSource dataSource) { - return new JdbcTemplate(dataSource); + catch (IOException e) { + throw new RuntimeException(e); } + } - @Bean - @Primary - @ConfigurationProperties("app.datasource") - public DataSourceProperties dataSourceProperties() { - return new DataSourceProperties(); - } + private static void dropTable(ApplicationContext context, String tableName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE"); + } - @Bean - public OracleDataSource dataSource(DataSourceProperties dataSourceProperties) { - return dataSourceProperties.initializeDataSourceBuilder().type(OracleDataSource.class).build(); + private static boolean isSortedByDistance(final List documents) { + final List distances = documents.stream() + .map(doc -> (Double) doc.getMetadata().get("distance")) + .toList(); + + if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + return true; } - @Bean - public EmbeddingModel embeddingModel() { - try { - TransformersEmbeddingModel tem = new TransformersEmbeddingModel(); - tem.afterPropertiesSet(); - return tem; - } - catch (Exception e) { - throw new RuntimeException("Failed initializing embedding model", e); + Iterator iter = distances.iterator(); + Double current; + Double previous = iter.next(); + while (iter.hasNext()) { + current = iter.next(); + if (previous > current) { + return false; } + previous = current; } - - } - - private static void dropTable(ApplicationContext context, String tableName) { - JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); - jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE"); + return true; } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "DOT", "EUCLIDEAN", "EUCLIDEAN_SQUARED", "MANHATTAN" }) public void addAndSearch(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -157,7 +149,7 @@ public void addAndSearch(String distanceType) { @CsvSource({ "COSINE,-1", "DOT,-1", "EUCLIDEAN,-1", "EUCLIDEAN_SQUARED,-1", "MANHATTAN,-1", "COSINE,75", "DOT,80", "EUCLIDEAN,60", "EUCLIDEAN_SQUARED,30", "MANHATTAN,42" }) public void searchWithFilters(String distanceType, int searchAccuracy) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + searchAccuracy) .run(context -> { @@ -231,7 +223,7 @@ public void searchWithFilters(String distanceType, int searchAccuracy) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "DOT", "EUCLIDEAN", "EUCLIDEAN_SQUARED", "MANHATTAN" }) public void documentUpdate(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -270,13 +262,13 @@ public void documentUpdate(String distanceType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "DOT" }) public void searchWithThreshold(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThresholdAll()); @@ -296,32 +288,58 @@ public void searchWithThreshold(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(1).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); dropTable(context, ((OracleVectorStore) vectorStore).getTableName()); }); } - private static boolean isSortedByDistance(final List documents) { - final List distances = documents.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestClient { - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { - return true; + @Value("${test.spring.ai.vectorstore.oracle.distanceType}") + OracleVectorStore.OracleVectorStoreDistanceType distanceType; + + @Value("${test.spring.ai.vectorstore.oracle.searchAccuracy}") + int searchAccuracy; + + @Bean + public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return new OracleVectorStore(jdbcTemplate, embeddingModel, OracleVectorStore.DEFAULT_TABLE_NAME, + OracleVectorStore.OracleVectorStoreIndexType.IVF, this.distanceType, 384, this.searchAccuracy, true, + true, true); } - Iterator iter = distances.iterator(); - Double current; - Double previous = iter.next(); - while (iter.hasNext()) { - current = iter.next(); - if (previous > current) { - return false; + @Bean + public JdbcTemplate myJdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public OracleDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(OracleDataSource.class).build(); + } + + @Bean + public EmbeddingModel embeddingModel() { + try { + TransformersEmbeddingModel tem = new TransformersEmbeddingModel(); + tem.afterPropertiesSet(); + return tem; + } + catch (Exception e) { + throw new RuntimeException("Failed initializing embedding model", e); } - previous = current; } - return true; + } } diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java index 8d2636fbf02..50a496e448f 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +23,16 @@ import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import oracle.jdbc.pool.OracleDataSource; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -47,15 +55,8 @@ import org.springframework.context.annotation.Primary; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.oracle.OracleContainer; -import org.testcontainers.utility.MountableFile; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import oracle.jdbc.pool.OracleDataSource; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -100,13 +101,13 @@ private static void dropTable(ApplicationContext context, String tableName) { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java index 23f35e2c92f..165e0a853c5 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java @@ -1,11 +1,28 @@ -package org.springframework.ai.vectorstore; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; +import static org.assertj.core.api.Assertions.assertThat; + public class SqlJsonPathFilterExpressionConverterTests { @Test diff --git a/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql b/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql index ac38a19652f..0b42b6ff7ea 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql +++ b/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + -- Exit on any errors WHENEVER SQLERROR EXIT SQL.SQLCODE diff --git a/vector-stores/spring-ai-pgvector-store/pom.xml b/vector-stores/spring-ai-pgvector-store/pom.xml index 2cd142626d2..7f3f5341623 100644 --- a/vector-stores/spring-ai-pgvector-store/pom.xml +++ b/vector-stores/spring-ai-pgvector-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java index 18d8e23bc16..06db63670c7 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.List; + import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; -import java.util.List; /** * Converts {@link Expression} into PgVector metadata filter expression format. diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java index f8017d97b32..5e7c30d39d9 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.ArrayList; @@ -21,6 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; @@ -64,7 +66,7 @@ public boolean isTableExists(String schemaName, String tableName) { String sql = "SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?"; try { // Query for a single integer value, if it exists, table exists - jdbcTemplate.queryForObject(sql, Integer.class, schemaName, tableName); + this.jdbcTemplate.queryForObject(sql, Integer.class, schemaName, tableName); return true; } catch (DataAccessException e) { @@ -100,7 +102,7 @@ void validateTableSchema(String schemaName, String tableName) { // Include the schema name in the query to target the correct table String query = "SELECT column_name, data_type FROM information_schema.columns " + "WHERE table_schema = ? AND table_name = ?"; - List> columns = jdbcTemplate.queryForList(query, + List> columns = this.jdbcTemplate.queryForList(query, new Object[] { schemaName, tableName }); if (columns.isEmpty()) { diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 486902fb8f7..56bb866e61b 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; @@ -23,6 +33,7 @@ import org.postgresql.util.PGobject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -44,15 +55,6 @@ import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; - /** * Uses the "vector_store" table to store the Spring AI vector data. The table and the * vector index will be auto-created if not available. @@ -67,8 +69,6 @@ */ public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class); - public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536; public static final int INVALID_EMBEDDING_DIMENSION = -1; @@ -81,10 +81,17 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public static final boolean DEFAULT_SCHEMA_VALIDATION = false; - public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); - public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000; + private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class); + + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + PgDistanceType.COSINE_DISTANCE, VectorStoreSimilarityMetric.COSINE, PgDistanceType.EUCLIDEAN_DISTANCE, + VectorStoreSimilarityMetric.EUCLIDEAN, PgDistanceType.NEGATIVE_INNER_PRODUCT, + VectorStoreSimilarityMetric.DOT); + + public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); + private final String vectorTableName; private final String vectorIndexName; @@ -183,7 +190,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT } public PgDistanceType getDistanceType() { - return distanceType; + return this.distanceType; } @Override @@ -208,6 +215,7 @@ private void insertOrUpdateBatch(List batch) { + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? "; this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + @Override public void setValues(PreparedStatement ps, int i) throws SQLException { @@ -247,7 +255,7 @@ private String toJson(Map map) { public Optional doDelete(List idList) { int updateCount = 0; for (String id : idList) { - int count = jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?", + int count = this.jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?", UUID.fromString(id)); updateCount = updateCount + count; } @@ -281,6 +289,7 @@ public List embeddingDistance(String query) { return this.jdbcTemplate.query( "SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM " + getFullyQualifiedTableName(), new RowMapper() { + @Override @Nullable public Double mapRow(ResultSet rs, int rowNum) throws SQLException { @@ -383,6 +392,23 @@ int embeddingDimensions() { return OPENAI_EMBEDDING_DIMENSION_SIZE; } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.PG_VECTOR.value(), operationName) + .withCollectionName(this.vectorTableName) + .withDimensions(this.embeddingDimensions()) + .withNamespace(this.schemaName) + .withSimilarityMetric(getSimilarityMetric()); + } + + private String getSimilarityMetric() { + if (!SIMILARITY_TYPE_MAPPING.containsKey(this.getDistanceType())) { + return this.getDistanceType().name(); + } + return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); + } + /** * By default, pgvector performs exact nearest neighbor search, which provides perfect * recall. You can add an index to use approximate nearest neighbor search, which @@ -492,7 +518,7 @@ private Map toMap(PGobject pgObject) { String source = pgObject.getValue(); try { - return (Map) objectMapper.readValue(source, Map.class); + return (Map) this.objectMapper.readValue(source, Map.class); } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -611,26 +637,4 @@ public PgVectorStore build() { } - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.PG_VECTOR.value(), operationName) - .withCollectionName(this.vectorTableName) - .withDimensions(this.embeddingDimensions()) - .withNamespace(this.schemaName) - .withSimilarityMetric(getSimilarityMetric()); - } - - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - PgDistanceType.COSINE_DISTANCE, VectorStoreSimilarityMetric.COSINE, PgDistanceType.EUCLIDEAN_DISTANCE, - VectorStoreSimilarityMetric.EUCLIDEAN, PgDistanceType.NEGATIVE_INNER_PRODUCT, - VectorStoreSimilarityMetric.DOT); - - private String getSimilarityMetric() { - if (!SIMILARITY_TYPE_MAPPING.containsKey(this.getDistanceType())) { - return this.getDistanceType().name(); - } - return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java index 4fa2c56a5ed..efef6d9135e 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.junit.jupiter.api.Test; @@ -46,32 +47,32 @@ public void explicitlySetDimensions() { final int explicitDimensions = 696; - var dim = new PgVectorStore(jdbcTemplate, embeddingModel, explicitDimensions).embeddingDimensions(); + var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel, explicitDimensions).embeddingDimensions(); assertThat(dim).isEqualTo(explicitDimensions); - verify(embeddingModel, never()).dimensions(); + verify(this.embeddingModel, never()).dimensions(); } @Test public void embeddingModelDimensions() { - when(embeddingModel.dimensions()).thenReturn(969); + when(this.embeddingModel.dimensions()).thenReturn(969); - var dim = new PgVectorStore(jdbcTemplate, embeddingModel).embeddingDimensions(); + var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); assertThat(dim).isEqualTo(969); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } @Test public void fallBackToDefaultDimensions() { - when(embeddingModel.dimensions()).thenThrow(new RuntimeException()); + when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); - var dim = new PgVectorStore(jdbcTemplate, embeddingModel).embeddingDimensions(); + var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); assertThat(dim).isEqualTo(PgVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java index ca662ab8e09..b2ef770cca8 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; -import static org.assertj.core.api.Assertions.assertThat; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; + +import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; @@ -28,10 +35,6 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.Filter.Value; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; /** * @author Muthukumaran Navaneethakrishnan @@ -44,14 +47,14 @@ public class PgVectorFilterExpressionConverterTests { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("$.country == \"BG\""); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("$.genre == \"drama\" && $.year >= 2020"); @@ -60,7 +63,7 @@ public void tesEqAndGte() { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr) .isEqualTo("($.genre == \"comedy\" || $.genre == \"documentary\" || $.genre == \"drama\")"); @@ -69,7 +72,7 @@ public void tesIn() { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -79,7 +82,7 @@ public void testNe() { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -90,7 +93,7 @@ public void testGroup() { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -102,7 +105,7 @@ public void tesBoolean() { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -111,7 +114,7 @@ public void testDecimal() { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("$.\"country 1 2 3\" == \"BG\""); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java index 5e4204cdbbc..0df031e63cd 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java index 5d34e98c6b4..626085dbaee 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Random; + +import javax.sql.DataSource; + import com.zaxxer.hikari.HikariDataSource; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.api.OpenAiApi; @@ -32,12 +41,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import javax.sql.DataSource; -import java.util.Random; import static org.assertj.core.api.Assertions.assertThat; @@ -92,19 +95,20 @@ private static boolean isSchemaExists(ApplicationContext context, String schemaN @Test public void shouldCreateDefaultTableAndIndexIfNotPresentInConfig() { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.schemaValidation=false").run(context -> { - assertThat(context).hasNotFailed(); - assertThat(isTableExists(context, "vector_store")).isTrue(); - assertThat(isSchemaExists(context, "public")).isTrue(); - dropTableByName(context, "vector_store"); + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.schemaValidation=false") + .run(context -> { + assertThat(context).hasNotFailed(); + assertThat(isTableExists(context, "vector_store")).isTrue(); + assertThat(isSchemaExists(context, "public")).isTrue(); + dropTableByName(context, "vector_store"); - }); + }); } @Test public void shouldCreateTableAndIndexIfNotPresentInDatabase() { String tableName = "new_vector_table"; - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName) .run(context -> { assertThat(isTableExists(context, tableName)).isTrue(); assertThat(isIndexExists(context, "public", tableName, tableName + "_index")).isTrue(); @@ -118,7 +122,7 @@ public void shouldFailWhenCustomTableIsAbsentAndValidationEnabled() { String tableName = "customvectortable"; - contextRunner + this.contextRunner .withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName, "test.spring.ai.vectorstore.pgvector.schemaValidation=true") @@ -136,7 +140,7 @@ public void shouldFailOnSQLInjectionAttemptInTableName() { String tableName = "users; DROP TABLE users;"; - contextRunner + this.contextRunner .withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName, "test.spring.ai.vectorstore.pgvector.schemaValidation=true") @@ -156,7 +160,7 @@ public void shouldFailOnSQLInjectionAttemptInSchemaName() { String schemaName = "public; DROP TABLE users;"; String tableName = "customvectortable"; - contextRunner + this.contextRunner .withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName, "test.spring.ai.vectorstore.pgvector.schemaName=" + schemaName, "test.spring.ai.vectorstore.pgvector.schemaValidation=true") @@ -189,10 +193,10 @@ public static class TestApplication { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - return new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(schemaName) - .withVectorTableName(vectorTableName) - .withVectorTableValidationsEnabled(schemaValidation) - .withDimensions(dimensions) + return new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(this.schemaName) + .withVectorTableName(this.vectorTableName) + .withVectorTableValidationsEnabled(this.schemaValidation) + .withDimensions(this.dimensions) .withDistanceType(PgVectorStore.PgDistanceType.COSINE_DISTANCE) .withRemoveExistingVectorStoreTable(true) .withIndexType(PgIndexType.HNSW) diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index f18a1669fa6..8405d32337f 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -28,12 +27,17 @@ import javax.sql.DataSource; +import com.zaxxer.hikari.HikariDataSource; import org.junit.Assert; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -53,11 +57,8 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.util.CollectionUtils; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.zaxxer.hikari.HikariDataSource; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Muthukumaran Navaneethakrishnan @@ -74,6 +75,16 @@ public class PgVectorStoreIT { .withUsername("postgres") .withPassword("postgres"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", + + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), "postgres"), + "app.datasource.username=postgres", "app.datasource.password=postgres", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -89,41 +100,59 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class) - .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", - - // JdbcTemplate configuration - String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), "postgres"), - "app.datasource.username=postgres", "app.datasource.password=postgres", - "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); - private static void dropTable(ApplicationContext context) { JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store"); } + static Stream provideFilters() { + return Stream.of(Arguments.of("country in ['BG','NL']", 3), // String Filters In + Arguments.of("year in [2020]", 1), // Numeric Filters In + Arguments.of("country not in ['BG']", 1), // String Filter Not In + Arguments.of("year not in [2020]", 2) // Numeric Filter Not In + ); + } + + private static boolean isSortedByDistance(List docs) { + + List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + + if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + return true; + } + + Iterator iter = distances.iterator(); + Float current, previous = iter.next(); + while (iter.hasNext()) { + current = iter.next(); + if (previous > current) { + return false; + } + previous = current; + } + return true; + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) public void addAndSearch(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -133,19 +162,11 @@ public void addAndSearch(String distanceType) { }); } - static Stream provideFilters() { - return Stream.of(Arguments.of("country in ['BG','NL']", 3), // String Filters In - Arguments.of("year in [2020]", 1), // Numeric Filters In - Arguments.of("country not in ['BG']", 1), // String Filter Not In - Arguments.of("year not in [2020]", 2) // Numeric Filter Not In - ); - } - @ParameterizedTest(name = "Filter expression {0} should return {1} records ") @MethodSource("provideFilters") public void searchWithInFilter(String expression, Integer expectedRecords) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE") + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE") .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -177,7 +198,7 @@ public void searchWithInFilter(String expression, Integer expectedRecords) { @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) public void searchWithFilters(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -251,7 +272,7 @@ public void searchWithFilters(String distanceType) { @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) public void documentUpdate(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -292,12 +313,12 @@ public void documentUpdate(String distanceType) { // @ValueSource(strings = { "COSINE_DISTANCE" }) public void searchWithThreshold(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThresholdAll()); @@ -317,32 +338,12 @@ public void searchWithThreshold(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(1).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); dropTable(context); }); } - private static boolean isSortedByDistance(List docs) { - - List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); - - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { - return true; - } - - Iterator iter = distances.iterator(); - Float current, previous = iter.next(); - while (iter.hasNext()) { - current = iter.next(); - if (previous > current) { - return false; - } - previous = current; - } - return true; - } - @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { @@ -353,7 +354,7 @@ public static class TestApplication { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { return new PgVectorStore(jdbcTemplate, embeddingModel, PgVectorStore.INVALID_EMBEDDING_DIMENSION, - distanceType, true, PgIndexType.HNSW, true); + this.distanceType, true, PgIndexType.HNSW, true); } @Bean diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java index f60b66b4940..963b256e915 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,8 +23,16 @@ import javax.sql.DataSource; +import com.zaxxer.hikari.HikariDataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -48,15 +55,8 @@ import org.springframework.context.annotation.Primary; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import com.zaxxer.hikari.HikariDataSource; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instruAbstractObservationVectorStorementation in @@ -75,6 +75,16 @@ public class PgVectorStoreObservationIT { .withUsername("postgres") .withPassword("postgres"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class) + .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", + + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), "postgres"), + "app.datasource.username=postgres", "app.datasource.password=postgres", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -90,26 +100,16 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class) - .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", - - // JdbcTemplate configuration - String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), "postgres"), - "app.datasource.username=postgres", "app.datasource.password=postgres", - "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); - @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java index 488dbd3f73e..993f8856ead 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Collections; + import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -29,13 +37,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.util.Collections; - -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.jdbc.core.BatchPreparedStatementSetter; -import org.springframework.jdbc.core.JdbcTemplate; - /** * @author Muthukumaran Navaneethakrishnan * @author Soby Chacko diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java index 5c151753d6a..abde63cfe84 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,6 @@ package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import java.util.List; import java.util.Map; @@ -32,6 +26,10 @@ import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.postgresql.ds.PGSimpleDataSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor; import org.springframework.ai.chat.messages.AssistantMessage; @@ -43,9 +41,12 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * @author Fabian Krüger @@ -55,38 +56,13 @@ @Testcontainers class PgVectorStoreWithChatMemoryAdvisorIT { - float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F }; - @Container @SuppressWarnings("resource") static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(PgVectorImage.DEFAULT_IMAGE) .withUsername("postgres") .withPassword("postgres"); - /** - * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar - * messages from the (gp)vector store. - */ - @Test - @DisplayName("Advised chat should have similar messages from vector store") - void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { - // faked ChatModel - ChatModel chatModel = chatModelAlwaysReturnsTheSameReply(); - // faked embedding model - EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); - PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); - - // do the chat - ChatClient.builder(chatModel) - .build() - .prompt() - .user("joke") - .advisors(new VectorStoreChatMemoryAdvisor(store)) - .call() - .chatResponse(); - - verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel); - } + float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F }; private static @NotNull ChatModel chatModelAlwaysReturnsTheSameReply() { ChatModel chatModel = mock(ChatModel.class); @@ -109,7 +85,7 @@ private static void initStore(PgVectorStore store) throws Exception { private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception { JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer(); PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withDimensions(3) // match - // embeddings + // embeddings .withInitializeSchema(true) .build(); initStore(vectorStore); @@ -124,20 +100,6 @@ private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingMode return new JdbcTemplate(ds); } - @SuppressWarnings("unchecked") - private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() { - EmbeddingModel embeddingModel = mock(EmbeddingModel.class); - - Mockito.doAnswer(invocationOnMock -> { - Object[] arguments = invocationOnMock.getArguments(); - List documents = (List) arguments[0]; - documents.forEach(d -> d.setEmbedding(embed)); - return List.of(embed, embed); - }).when(embeddingModel).embed(ArgumentMatchers.any(), any(), any()); - when(embeddingModel.embed(any(String.class))).thenReturn(embed); - return embeddingModel; - } - private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) { ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); verify(chatModel).call(promptCaptor.capture()); @@ -156,4 +118,43 @@ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatM """); } -} \ No newline at end of file + /** + * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar + * messages from the (gp)vector store. + */ + @Test + @DisplayName("Advised chat should have similar messages from vector store") + void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { + // faked ChatModel + ChatModel chatModel = chatModelAlwaysReturnsTheSameReply(); + // faked embedding model + EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); + PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); + + // do the chat + ChatClient.builder(chatModel) + .build() + .prompt() + .user("joke") + .advisors(new VectorStoreChatMemoryAdvisor(store)) + .call() + .chatResponse(); + + verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel); + } + + @SuppressWarnings("unchecked") + private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() { + EmbeddingModel embeddingModel = mock(EmbeddingModel.class); + + Mockito.doAnswer(invocationOnMock -> { + Object[] arguments = invocationOnMock.getArguments(); + List documents = (List) arguments[0]; + documents.forEach(d -> d.setEmbedding(this.embed)); + return List.of(this.embed, this.embed); + }).when(embeddingModel).embed(ArgumentMatchers.any(), any(), any()); + when(embeddingModel.embed(any(String.class))).thenReturn(this.embed); + return embeddingModel; + } + +} diff --git a/vector-stores/spring-ai-pinecone-store/pom.xml b/vector-stores/spring-ai-pinecone-store/pom.xml index b1c883453fb..87b2722c561 100644 --- a/vector-stores/spring-ai-pinecone-store/pom.xml +++ b/vector-stores/spring-ai-pinecone-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index 59a7f7ff8e1..f4243653b07 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -21,27 +21,11 @@ import java.util.Map; import java.util.Optional; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.filter.converter.PineconeFilterExpressionConverter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.util.Assert; -import org.springframework.util.StringUtils; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.Struct; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; - import io.micrometer.observation.ObservationRegistry; import io.pinecone.PineconeClient; import io.pinecone.PineconeClientConfig; @@ -53,6 +37,21 @@ import io.pinecone.proto.UpsertRequest; import io.pinecone.proto.Vector; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.filter.converter.PineconeFilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + /** * A VectorStore implementation backed by Pinecone, a cloud-based vector database. This * store supports creating, updating, deleting, and similarity searching of documents in a @@ -86,180 +85,6 @@ public class PineconeVectorStore extends AbstractObservationVectorStore { private final BatchingStrategy batchingStrategy; - /** - * Configuration class for the PineconeVectorStore. - */ - public static final class PineconeVectorStoreConfig { - - // The free tier (gcp-starter) doesn't support Namespaces. - // Leave the namespace empty (e.g. "") for the free tier. - private final String namespace; - - private final String contentFieldName; - - private final String distanceMetadataFieldName; - - private final PineconeConnectionConfig connectionConfig; - - private final PineconeClientConfig clientConfig; - - // private final int defaultSimilarityTopK; - - /** - * Constructor using the builder. - * @param builder The configuration builder. - */ - /** - * Constructor using the builder. - * @param builder The configuration builder. - */ - public PineconeVectorStoreConfig(Builder builder) { - this.namespace = builder.namespace; - this.contentFieldName = builder.contentFieldName; - this.distanceMetadataFieldName = builder.distanceMetadataFieldName; - - // this.defaultSimilarityTopK = builder.defaultSimilarityTopK; - this.connectionConfig = new PineconeConnectionConfig().withIndexName(builder.indexName); - this.clientConfig = new PineconeClientConfig().withApiKey(builder.apiKey) - .withEnvironment(builder.environment) - .withProjectName(builder.projectId) - .withApiKey(builder.apiKey) - .withServerSideTimeoutSec((int) builder.serverSideTimeout.toSeconds()); - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * {@return the default config} - */ - public static PineconeVectorStoreConfig defaultConfig() { - return builder().build(); - } - - public static class Builder { - - private String apiKey; - - private String projectId; - - private String environment; - - private String indexName; - - // The free-tier (gcp-starter) doesn't support Namespaces! - private String namespace = ""; - - private String contentFieldName = CONTENT_FIELD_NAME; - - private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME; - - /** - * Optional server-side timeout in seconds for all operations. Default: 20 - * seconds. - */ - private Duration serverSideTimeout = Duration.ofSeconds(20); - - private Builder() { - } - - /** - * Pinecone api key. - * @param apiKey key to use. - * @return this builder. - */ - public Builder withApiKey(String apiKey) { - this.apiKey = apiKey; - return this; - } - - /** - * Pinecone project id. - * @param projectId Project id to use. - * @return this builder. - */ - public Builder withProjectId(String projectId) { - this.projectId = projectId; - return this; - } - - /** - * Pinecone environment name. - * @param environment Environment name (e.g. gcp-starter). - * @return this builder. - */ - public Builder withEnvironment(String environment) { - this.environment = environment; - return this; - } - - /** - * Pinecone index name. - * @param indexName Pinecone index name to use. - * @return this builder. - */ - public Builder withIndexName(String indexName) { - this.indexName = indexName; - return this; - } - - /** - * Pinecone Namespace. The free-tier (gcp-starter) doesn't support Namespaces. - * For free-tier leave the namespace empty. - * @param namespace Pinecone namespace to use. - * @return this builder. - */ - public Builder withNamespace(String namespace) { - this.namespace = namespace; - return this; - } - - /** - * Content field name. - * @param contentFieldName content field name to use. - * @return this builder. - */ - public Builder withContentFieldName(String contentFieldName) { - this.contentFieldName = contentFieldName; - return this; - } - - /** - * Distance metadata field name. - * @param distanceMetadataFieldName distance metadata field name to use. - * @return this builder. - */ - public Builder withDistanceMetadataFieldName(String distanceMetadataFieldName) { - this.distanceMetadataFieldName = distanceMetadataFieldName; - return this; - } - - /** - * Pinecone server side timeout. - * @param serverSideTimeout server timeout to use. - * @return this builder. - */ - public Builder withServerSideTimeout(Duration serverSideTimeout) { - this.serverSideTimeout = serverSideTimeout; - return this; - } - - /** - * {@return the immutable configuration} - */ - public PineconeVectorStoreConfig build() { - return new PineconeVectorStoreConfig(this); - } - - } - - } - /** * Constructs a new PineconeVectorStore. * @param config The configuration for the store. @@ -442,6 +267,7 @@ private Map extractMetadata(Struct metadataStruct) { try { String json = JsonFormat.printer().print(metadataStruct); Map metadata = this.objectMapper.readValue(json, new TypeReference>() { + }); metadata.remove(this.pineconeContentFieldName); return metadata; @@ -461,4 +287,178 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .withFieldName(this.pineconeContentFieldName); } + /** + * Configuration class for the PineconeVectorStore. + */ + public static final class PineconeVectorStoreConfig { + + // The free tier (gcp-starter) doesn't support Namespaces. + // Leave the namespace empty (e.g. "") for the free tier. + private final String namespace; + + private final String contentFieldName; + + private final String distanceMetadataFieldName; + + private final PineconeConnectionConfig connectionConfig; + + private final PineconeClientConfig clientConfig; + + // private final int defaultSimilarityTopK; + + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + public PineconeVectorStoreConfig(Builder builder) { + this.namespace = builder.namespace; + this.contentFieldName = builder.contentFieldName; + this.distanceMetadataFieldName = builder.distanceMetadataFieldName; + + // this.defaultSimilarityTopK = builder.defaultSimilarityTopK; + this.connectionConfig = new PineconeConnectionConfig().withIndexName(builder.indexName); + this.clientConfig = new PineconeClientConfig().withApiKey(builder.apiKey) + .withEnvironment(builder.environment) + .withProjectName(builder.projectId) + .withApiKey(builder.apiKey) + .withServerSideTimeoutSec((int) builder.serverSideTimeout.toSeconds()); + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static PineconeVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String apiKey; + + private String projectId; + + private String environment; + + private String indexName; + + // The free-tier (gcp-starter) doesn't support Namespaces! + private String namespace = ""; + + private String contentFieldName = CONTENT_FIELD_NAME; + + private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME; + + /** + * Optional server-side timeout in seconds for all operations. Default: 20 + * seconds. + */ + private Duration serverSideTimeout = Duration.ofSeconds(20); + + private Builder() { + } + + /** + * Pinecone api key. + * @param apiKey key to use. + * @return this builder. + */ + public Builder withApiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Pinecone project id. + * @param projectId Project id to use. + * @return this builder. + */ + public Builder withProjectId(String projectId) { + this.projectId = projectId; + return this; + } + + /** + * Pinecone environment name. + * @param environment Environment name (e.g. gcp-starter). + * @return this builder. + */ + public Builder withEnvironment(String environment) { + this.environment = environment; + return this; + } + + /** + * Pinecone index name. + * @param indexName Pinecone index name to use. + * @return this builder. + */ + public Builder withIndexName(String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Pinecone Namespace. The free-tier (gcp-starter) doesn't support Namespaces. + * For free-tier leave the namespace empty. + * @param namespace Pinecone namespace to use. + * @return this builder. + */ + public Builder withNamespace(String namespace) { + this.namespace = namespace; + return this; + } + + /** + * Content field name. + * @param contentFieldName content field name to use. + * @return this builder. + */ + public Builder withContentFieldName(String contentFieldName) { + this.contentFieldName = contentFieldName; + return this; + } + + /** + * Distance metadata field name. + * @param distanceMetadataFieldName distance metadata field name to use. + * @return this builder. + */ + public Builder withDistanceMetadataFieldName(String distanceMetadataFieldName) { + this.distanceMetadataFieldName = distanceMetadataFieldName; + return this; + } + + /** + * Pinecone server side timeout. + * @param serverSideTimeout server timeout to use. + * @return this builder. + */ + public Builder withServerSideTimeout(Duration serverSideTimeout) { + this.serverSideTimeout = serverSideTimeout; + return this; + } + + /** + * {@return the immutable configuration} + */ + public PineconeVectorStoreConfig build() { + return new PineconeVectorStoreConfig(this); + } + + } + + } + } diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java index c9b63fbb471..20c0dd75df5 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java @@ -1,11 +1,27 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; +import java.util.Set; + import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; -import java.util.Set; - /** * Registration of AOT hints for Pinecone's vector store. * diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index a0f8a257cd9..917abaf4745 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -62,6 +63,9 @@ public class PineconeVectorStoreIT { private static final String CUSTOM_CONTENT_FIELD_NAME = "article"; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -77,9 +81,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -90,11 +91,11 @@ public static void beforeAll() { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -104,14 +105,14 @@ public void addAndSearchTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); @@ -125,7 +126,7 @@ public void addAndSearchWithFilters() { // Pinecone metadata filtering syntax: // https://docs.pinecone.io/docs/metadata-filtering - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -177,7 +178,7 @@ public void addAndSearchWithFilters() { public void documentUpdateTest() { // Note ,using OpenAI to calculate embeddings - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -234,11 +235,11 @@ public void documentUpdateTest() { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await().until(() -> { return vectorStore @@ -259,13 +260,13 @@ public void searchThresholdTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); }, hasSize(0)); @@ -301,4 +302,4 @@ public TransformersEmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java index 8fc6b41f774..b8a84de02c4 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,11 +22,15 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.awaitility.Awaitility; import org.awaitility.Duration; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -45,9 +47,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; /** * @author Christian Tzolov @@ -67,6 +68,9 @@ public class PineconeVectorStoreObservationIT { private static final String CUSTOM_CONTENT_FIELD_NAME = "article"; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -82,9 +86,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -95,13 +96,13 @@ public static void beforeAll() { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() @@ -165,7 +166,7 @@ void observationVectorStoreAddAndQueryOperations() { .hasBeenStopped(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); diff --git a/vector-stores/spring-ai-qdrant-store/pom.xml b/vector-stores/spring-ai-qdrant-store/pom.xml index f88c4f84e5e..2d9c598e529 100644 --- a/vector-stores/spring-ai-qdrant-store/pom.xml +++ b/vector-stores/spring-ai-qdrant-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java index ec1d1fbdd97..1870a37aabe 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.qdrant; -import static io.qdrant.client.ConditionFactory.filter; -import static io.qdrant.client.ConditionFactory.match; -import static io.qdrant.client.ConditionFactory.matchExceptKeywords; -import static io.qdrant.client.ConditionFactory.matchExceptValues; -import static io.qdrant.client.ConditionFactory.matchKeyword; -import static io.qdrant.client.ConditionFactory.matchKeywords; -import static io.qdrant.client.ConditionFactory.matchValues; -import static io.qdrant.client.ConditionFactory.range; +package org.springframework.ai.vectorstore.qdrant; import java.util.ArrayList; import java.util.List; +import io.qdrant.client.grpc.Points.Condition; +import io.qdrant.client.grpc.Points.Filter; +import io.qdrant.client.grpc.Points.Range; + import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; @@ -34,9 +30,14 @@ import org.springframework.ai.vectorstore.filter.Filter.Operand; import org.springframework.ai.vectorstore.filter.Filter.Value; -import io.qdrant.client.grpc.Points.Condition; -import io.qdrant.client.grpc.Points.Filter; -import io.qdrant.client.grpc.Points.Range; +import static io.qdrant.client.ConditionFactory.filter; +import static io.qdrant.client.ConditionFactory.match; +import static io.qdrant.client.ConditionFactory.matchExceptKeywords; +import static io.qdrant.client.ConditionFactory.matchExceptValues; +import static io.qdrant.client.ConditionFactory.matchKeyword; +import static io.qdrant.client.ConditionFactory.matchKeywords; +import static io.qdrant.client.ConditionFactory.matchValues; +import static io.qdrant.client.ConditionFactory.range; /** * @author Anush Shetty diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java index 8f86cb87206..00ad1e518cc 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import java.util.Map; diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java index 87c1067dd61..13862abc068 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import java.lang.reflect.Array; diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index 158caf90352..8fadc237ac6 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,17 +16,24 @@ package org.springframework.ai.vectorstore.qdrant; -import static io.qdrant.client.PointIdFactory.id; -import static io.qdrant.client.ValueFactory.value; -import static io.qdrant.client.VectorsFactory.vectors; -import static io.qdrant.client.WithPayloadSelectorFactory.enable; - import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ExecutionException; +import io.micrometer.observation.ObservationRegistry; +import io.qdrant.client.QdrantClient; +import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.Collections.VectorParams; +import io.qdrant.client.grpc.JsonWithInt.Value; +import io.qdrant.client.grpc.Points.Filter; +import io.qdrant.client.grpc.Points.PointId; +import io.qdrant.client.grpc.Points.PointStruct; +import io.qdrant.client.grpc.Points.ScoredPoint; +import io.qdrant.client.grpc.Points.SearchPoints; +import io.qdrant.client.grpc.Points.UpdateStatus; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -34,7 +41,6 @@ import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; @@ -42,17 +48,10 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; -import io.qdrant.client.QdrantClient; -import io.qdrant.client.grpc.Collections.Distance; -import io.qdrant.client.grpc.Collections.VectorParams; -import io.qdrant.client.grpc.JsonWithInt.Value; -import io.qdrant.client.grpc.Points.Filter; -import io.qdrant.client.grpc.Points.PointId; -import io.qdrant.client.grpc.Points.PointStruct; -import io.qdrant.client.grpc.Points.ScoredPoint; -import io.qdrant.client.grpc.Points.SearchPoints; -import io.qdrant.client.grpc.Points.UpdateStatus; +import static io.qdrant.client.PointIdFactory.id; +import static io.qdrant.client.ValueFactory.value; +import static io.qdrant.client.VectorsFactory.vectors; +import static io.qdrant.client.WithPayloadSelectorFactory.enable; /** * Qdrant vectorStore implementation. This store supports creating, updating, deleting, @@ -67,12 +66,12 @@ */ public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean { + public static final String DEFAULT_COLLECTION_NAME = "vector_store"; + private static final String CONTENT_FIELD_NAME = "doc_content"; private static final String DISTANCE_FIELD_NAME = "distance"; - public static final String DEFAULT_COLLECTION_NAME = "vector_store"; - private final EmbeddingModel embeddingModel; private final QdrantClient qdrantClient; @@ -85,68 +84,6 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements private final BatchingStrategy batchingStrategy; - /** - * Configuration class for the QdrantVectorStore. - * - * @deprecated since 1.0.0 in favor of {@link QdrantVectorStore}. - */ - @Deprecated(since = "1.0.0", forRemoval = true) - public static final class QdrantVectorStoreConfig { - - private final String collectionName; - - /* - * Constructor using the builder. - * - * @param builder The configuration builder. - */ - - private QdrantVectorStoreConfig(Builder builder) { - this.collectionName = builder.collectionName; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * {@return the default config} - */ - public static QdrantVectorStoreConfig defaultConfig() { - return builder().build(); - } - - public static class Builder { - - private String collectionName; - - private Builder() { - } - - /** - * @param collectionName REQUIRED. The name of the collection. - */ - public Builder withCollectionName(String collectionName) { - this.collectionName = collectionName; - return this; - } - - /** - * {@return the immutable configuration} - */ - public QdrantVectorStoreConfig build() { - Assert.notNull(collectionName, "collectionName cannot be null"); - return new QdrantVectorStoreConfig(this); - } - - } - - } - /** * Constructs a new QdrantVectorStore. * @param config The configuration for the store. @@ -319,8 +256,9 @@ private Map toPayload(Document document) { @Override public void afterPropertiesSet() throws Exception { - if (!this.initializeSchema) + if (!this.initializeSchema) { return; + } // Create the collection if it does not exist. if (!isCollectionExists()) { @@ -350,4 +288,66 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str } -} \ No newline at end of file + /** + * Configuration class for the QdrantVectorStore. + * + * @deprecated since 1.0.0 in favor of {@link QdrantVectorStore}. + */ + @Deprecated(since = "1.0.0", forRemoval = true) + public static final class QdrantVectorStoreConfig { + + private final String collectionName; + + /* + * Constructor using the builder. + * + * @param builder The configuration builder. + */ + + private QdrantVectorStoreConfig(Builder builder) { + this.collectionName = builder.collectionName; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static QdrantVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String collectionName; + + private Builder() { + } + + /** + * @param collectionName REQUIRED. The name of the collection. + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * {@return the immutable configuration} + */ + public QdrantVectorStoreConfig build() { + Assert.notNull(this.collectionName, "collectionName cannot be null"); + return new QdrantVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java index 4a6592ffad0..2045be309f5 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 7cd00a4eff3..b5e30f9454d 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import java.util.Collections; @@ -28,14 +29,14 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.mistralai.MistralAiEmbeddingModel; -import org.springframework.ai.mistralai.api.MistralAiApi; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.qdrant.QdrantContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.mistralai.MistralAiEmbeddingModel; +import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.boot.SpringBootConfiguration; @@ -62,6 +63,10 @@ public class QdrantVectorStoreIT { @Container static QdrantContainer qdrantContainer = new QdrantContainer(QdrantImage.DEFAULT_IMAGE); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -70,10 +75,6 @@ public class QdrantVectorStoreIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class) - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); - @BeforeAll static void setup() throws InterruptedException, ExecutionException { @@ -91,23 +92,23 @@ static void setup() throws InterruptedException, ExecutionException { @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results2).hasSize(0); @@ -117,7 +118,7 @@ public void addAndSearch() { @Test public void addAndSearchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -166,7 +167,7 @@ public void addAndSearchWithFilters() { @Test public void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -206,11 +207,11 @@ public void documentUpdateTest() { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); @@ -225,14 +226,14 @@ public void searchThresholdTest() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); }); } diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java index 5d181cf03d4..42031c3c5b2 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.qdrant; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.qdrant; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -23,9 +22,20 @@ import java.util.Map; import java.util.concurrent.ExecutionException; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import io.qdrant.client.QdrantClient; +import io.qdrant.client.QdrantGrpcClient; +import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.Collections.VectorParams; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -43,17 +53,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import io.qdrant.client.QdrantClient; -import io.qdrant.client.QdrantGrpcClient; -import io.qdrant.client.grpc.Collections.Distance; -import io.qdrant.client.grpc.Collections.VectorParams; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -70,6 +71,9 @@ public class QdrantVectorStoreObservationIT { @Container static QdrantContainer qdrantContainer = new QdrantContainer(QdrantImage.DEFAULT_IMAGE); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -85,9 +89,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll static void setup() throws InterruptedException, ExecutionException { @@ -106,13 +107,13 @@ static void setup() throws InterruptedException, ExecutionException { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-redis-store/pom.xml b/vector-stores/spring-ai-redis-store/pom.xml index 6d8a0ca9a25..cf0c0e4d2ed 100644 --- a/vector-stores/spring-ai-redis-store/pom.xml +++ b/vector-stores/spring-ai-redis-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java index 95198ff3b94..86638c071ac 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.text.MessageFormat; @@ -177,6 +178,7 @@ private NumericBoundary exclusive(Value value) { } static record Numeric(NumericBoundary lower, NumericBoundary upper) { + } static record NumericBoundary(Object value, boolean exclusive) { @@ -197,11 +199,11 @@ public String toString() { if (this == POSITIVE_INFINITY) { return INFINITY; } - return String.format(formatString(), value); + return String.format(formatString(), this.value); } private String formatString() { - if (exclusive) { + if (this.exclusive) { return EXCLUSIVE_FORMAT; } return INCLUSIVE_FORMAT; diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index d6c7c6808d8..3a850f1dc88 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -27,24 +27,9 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; - -import io.micrometer.observation.ObservationRegistry; import redis.clients.jedis.JedisPooled; import redis.clients.jedis.Pipeline; import redis.clients.jedis.json.Path2; @@ -61,6 +46,21 @@ import redis.clients.jedis.search.schemafields.VectorField; import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + /** * The RedisVectorStore is for managing and querying vector data in a Redis database. It * offers functionalities like adding, deleting, and performing similarity searches on @@ -87,165 +87,6 @@ */ public class RedisVectorStore extends AbstractObservationVectorStore implements InitializingBean { - public enum Algorithm { - - FLAT, HSNW - - } - - public record MetadataField(String name, FieldType fieldType) { - - public static MetadataField text(String name) { - return new MetadataField(name, FieldType.TEXT); - } - - public static MetadataField numeric(String name) { - return new MetadataField(name, FieldType.NUMERIC); - } - - public static MetadataField tag(String name) { - return new MetadataField(name, FieldType.TAG); - } - - } - - /** - * Configuration for the Redis vector store. - */ - public static final class RedisVectorStoreConfig { - - private final String indexName; - - private final String prefix; - - private final String contentFieldName; - - private final String embeddingFieldName; - - private final Algorithm vectorAlgorithm; - - private final List metadataFields; - - private RedisVectorStoreConfig() { - this(builder()); - } - - private RedisVectorStoreConfig(Builder builder) { - this.indexName = builder.indexName; - this.prefix = builder.prefix; - this.contentFieldName = builder.contentFieldName; - this.embeddingFieldName = builder.embeddingFieldName; - this.vectorAlgorithm = builder.vectorAlgorithm; - this.metadataFields = builder.metadataFields; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - /** - * {@return the default config} - */ - public static RedisVectorStoreConfig defaultConfig() { - - return builder().build(); - } - - public static class Builder { - - private String indexName = DEFAULT_INDEX_NAME; - - private String prefix = DEFAULT_PREFIX; - - private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; - - private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; - - private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; - - private List metadataFields = new ArrayList<>(); - - private Builder() { - } - - /** - * Configures the Redis index name to use. - * @param name the index name to use - * @return this builder - */ - public Builder withIndexName(String name) { - this.indexName = name; - return this; - } - - /** - * Configures the Redis key prefix to use (default: "embedding:"). - * @param prefix the prefix to use - * @return this builder - */ - public Builder withPrefix(String prefix) { - this.prefix = prefix; - return this; - } - - /** - * Configures the Redis content field name to use. - * @param name the content field name to use - * @return this builder - */ - public Builder withContentFieldName(String name) { - this.contentFieldName = name; - return this; - } - - /** - * Configures the Redis embedding field name to use. - * @param name the embedding field name to use - * @return this builder - */ - public Builder withEmbeddingFieldName(String name) { - this.embeddingFieldName = name; - return this; - } - - /** - * Configures the Redis vector algorithmto use. - * @param algorithm the vector algorithm to use - * @return this builder - */ - public Builder withVectorAlgorithm(Algorithm algorithm) { - this.vectorAlgorithm = algorithm; - return this; - } - - public Builder withMetadataFields(MetadataField... fields) { - return withMetadataFields(Arrays.asList(fields)); - } - - public Builder withMetadataFields(List fields) { - this.metadataFields = fields; - return this; - } - - /** - * {@return the immutable configuration} - */ - public RedisVectorStoreConfig build() { - - return new RedisVectorStoreConfig(this); - } - - } - - } - - private final boolean initializeSchema; - public static final String DEFAULT_INDEX_NAME = "spring-ai-index"; public static final String DEFAULT_CONTENT_FIELD_NAME = "content"; @@ -256,6 +97,8 @@ public RedisVectorStoreConfig build() { public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW; + public static final String DISTANCE_FIELD_NAME = "vector_score"; + private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]"; private static final Path2 JSON_SET_PATH = Path2.of("$"); @@ -272,20 +115,20 @@ public RedisVectorStoreConfig build() { private static final String EMBEDDING_PARAM_NAME = "BLOB"; - public static final String DISTANCE_FIELD_NAME = "vector_score"; - private static final String DEFAULT_DISTANCE_METRIC = "COSINE"; + private final boolean initializeSchema; + private final JedisPooled jedis; private final EmbeddingModel embeddingModel; private final RedisVectorStoreConfig config; - private FilterExpressionConverter filterExpressionConverter; - private final BatchingStrategy batchingStrategy; + private FilterExpressionConverter filterExpressionConverter; + public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema) { @@ -475,7 +318,7 @@ private SchemaField schemaField(MetadataField field) { } private VectorAlgorithm vectorAlgorithm() { - if (config.vectorAlgorithm == Algorithm.HSNW) { + if (this.config.vectorAlgorithm == Algorithm.HSNW) { return VectorAlgorithm.HNSW; } return VectorAlgorithm.FLAT; @@ -497,4 +340,161 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str } -} \ No newline at end of file + public enum Algorithm { + + FLAT, HSNW + + } + + public record MetadataField(String name, FieldType fieldType) { + + public static MetadataField text(String name) { + return new MetadataField(name, FieldType.TEXT); + } + + public static MetadataField numeric(String name) { + return new MetadataField(name, FieldType.NUMERIC); + } + + public static MetadataField tag(String name) { + return new MetadataField(name, FieldType.TAG); + } + + } + + /** + * Configuration for the Redis vector store. + */ + public static final class RedisVectorStoreConfig { + + private final String indexName; + + private final String prefix; + + private final String contentFieldName; + + private final String embeddingFieldName; + + private final Algorithm vectorAlgorithm; + + private final List metadataFields; + + private RedisVectorStoreConfig() { + this(builder()); + } + + private RedisVectorStoreConfig(Builder builder) { + this.indexName = builder.indexName; + this.prefix = builder.prefix; + this.contentFieldName = builder.contentFieldName; + this.embeddingFieldName = builder.embeddingFieldName; + this.vectorAlgorithm = builder.vectorAlgorithm; + this.metadataFields = builder.metadataFields; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + /** + * {@return the default config} + */ + public static RedisVectorStoreConfig defaultConfig() { + + return builder().build(); + } + + public static class Builder { + + private String indexName = DEFAULT_INDEX_NAME; + + private String prefix = DEFAULT_PREFIX; + + private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; + + private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; + + private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; + + private List metadataFields = new ArrayList<>(); + + private Builder() { + } + + /** + * Configures the Redis index name to use. + * @param name the index name to use + * @return this builder + */ + public Builder withIndexName(String name) { + this.indexName = name; + return this; + } + + /** + * Configures the Redis key prefix to use (default: "embedding:"). + * @param prefix the prefix to use + * @return this builder + */ + public Builder withPrefix(String prefix) { + this.prefix = prefix; + return this; + } + + /** + * Configures the Redis content field name to use. + * @param name the content field name to use + * @return this builder + */ + public Builder withContentFieldName(String name) { + this.contentFieldName = name; + return this; + } + + /** + * Configures the Redis embedding field name to use. + * @param name the embedding field name to use + * @return this builder + */ + public Builder withEmbeddingFieldName(String name) { + this.embeddingFieldName = name; + return this; + } + + /** + * Configures the Redis vector algorithmto use. + * @param algorithm the vector algorithm to use + * @return this builder + */ + public Builder withVectorAlgorithm(Algorithm algorithm) { + this.vectorAlgorithm = algorithm; + return this; + } + + public Builder withMetadataFields(MetadataField... fields) { + return withMetadataFields(Arrays.asList(fields)); + } + + public Builder withMetadataFields(List fields) { + this.metadataFields = fields; + return this; + } + + /** + * {@return the immutable configuration} + */ + public RedisVectorStoreConfig build() { + + return new RedisVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java index f1e96534c77..c2c0901ba0c 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; + import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag; import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric; +import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; @@ -27,16 +39,6 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import java.util.Arrays; -import java.util.List; - -import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.Filter.Value; - /** * @author Julien Ruaux */ diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java index 44497602d67..124d76d5885 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,8 +23,13 @@ import java.util.Map; import java.util.UUID; +import com.redis.testcontainers.RedisStackContainer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -40,11 +44,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.redis.testcontainers.RedisStackContainer; -import redis.clients.jedis.JedisPooled; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Julien Ruaux @@ -93,24 +94,24 @@ void ensureIndexGetsCreated() { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).isEmpty(); @@ -120,7 +121,7 @@ void addAndSearch() { @Test void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -174,7 +175,7 @@ void searchWithFilters() throws InterruptedException { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -215,11 +216,11 @@ void documentUpdate() { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); @@ -237,7 +238,7 @@ void searchWithThreshold() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java index 64b30ce05be..2d2ed538cb2 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import com.redis.testcontainers.RedisStackContainer; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -44,15 +51,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import com.redis.testcontainers.RedisStackContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import redis.clients.jedis.JedisPooled; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -65,6 +65,11 @@ public class RedisVectorStoreObservationIT { static RedisStackContainer redisContainer = new RedisStackContainer( RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI()); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -80,11 +85,6 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI()); - @BeforeEach void cleanDatabase() { this.contextRunner.run(context -> context.getBean(RedisVectorStore.class).getJedis().flushAll()); @@ -93,13 +93,13 @@ void cleanDatabase() { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-typesense-store/pom.xml b/vector-stores/spring-ai-typesense-store/pom.xml index b43f7711d58..14087b76be1 100644 --- a/vector-stores/spring-ai-typesense-store/pom.xml +++ b/vector-stores/spring-ai-typesense-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java index 0f19340d8c3..5706dd15416 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; @@ -40,7 +56,7 @@ private String getOperationSymbol(Filter.Expression exp) { return " "; // in typesense "IN" operator looks like -> country: [USA, UK] case NIN: return " != "; // in typesense "NIN" operator looks like -> country: - // !=[USA, UK] + // !=[USA, UK] default: throw new RuntimeException("Not supported expression type:" + exp.type()); } diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java index c27161a0f84..1a7d0b36d38 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -23,8 +23,20 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.typesense.api.Client; +import org.typesense.api.FieldTypes; +import org.typesense.model.CollectionResponse; +import org.typesense.model.CollectionSchema; +import org.typesense.model.DeleteDocumentsParameters; +import org.typesense.model.Field; +import org.typesense.model.ImportDocumentsParameters; +import org.typesense.model.MultiSearchCollectionParameters; +import org.typesense.model.MultiSearchResult; +import org.typesense.model.MultiSearchSearchesParameter; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -38,18 +50,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import org.typesense.api.Client; -import org.typesense.api.FieldTypes; -import org.typesense.model.CollectionResponse; -import org.typesense.model.CollectionSchema; -import org.typesense.model.DeleteDocumentsParameters; -import org.typesense.model.Field; -import org.typesense.model.ImportDocumentsParameters; -import org.typesense.model.MultiSearchCollectionParameters; -import org.typesense.model.MultiSearchResult; -import org.typesense.model.MultiSearchSearchesParameter; - -import io.micrometer.observation.ObservationRegistry; /** * @author Pablo Sanchidrian Herrera @@ -58,8 +58,6 @@ */ public class TypesenseVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(TypesenseVectorStore.class); - /** * The name of the field that contains the document ID. It is mandatory to set "id" as * the field name because that is the name that typesense is going to look for. @@ -78,88 +76,20 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme public static final int INVALID_EMBEDDING_DIMENSION = -1; + private static final Logger logger = LoggerFactory.getLogger(TypesenseVectorStore.class); + + public final FilterExpressionConverter filterExpressionConverter = new TypesenseFilterExpressionConverter(); + private final Client client; private final EmbeddingModel embeddingModel; private final TypesenseVectorStoreConfig config; - public final FilterExpressionConverter filterExpressionConverter = new TypesenseFilterExpressionConverter(); - private final boolean initializeSchema; private final BatchingStrategy batchingStrategy; - public static class TypesenseVectorStoreConfig { - - private final String collectionName; - - private final int embeddingDimension; - - public TypesenseVectorStoreConfig(String collectionName, int embeddingDimension) { - this.collectionName = collectionName; - this.embeddingDimension = embeddingDimension; - } - - /** - * {@return the default config} - */ - public static TypesenseVectorStoreConfig defaultConfig() { - return builder().build(); - } - - private TypesenseVectorStoreConfig(Builder builder) { - this.collectionName = builder.collectionName; - this.embeddingDimension = builder.embeddingDimension; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - public static class Builder { - - private String collectionName; - - private int embeddingDimension; - - /** - * Set the collection name. - * @param collectionName The collection name. - * @return The builder. - */ - public Builder withCollectionName(String collectionName) { - this.collectionName = collectionName; - return this; - } - - /** - * Set the embedding dimension. - * @param embeddingDimension The embedding dimension. - * @return The builder. - */ - public Builder withEmbeddingDimension(int embeddingDimension) { - this.embeddingDimension = embeddingDimension; - return this; - } - - /** - * Build the configuration. - * @return The configuration. - */ - public TypesenseVectorStoreConfig build() { - return new TypesenseVectorStoreConfig(this); - } - - } - - } - public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel) { this(client, embeddingModel, TypesenseVectorStoreConfig.defaultConfig(), false); } @@ -396,4 +326,74 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); } + public static class TypesenseVectorStoreConfig { + + private final String collectionName; + + private final int embeddingDimension; + + public TypesenseVectorStoreConfig(String collectionName, int embeddingDimension) { + this.collectionName = collectionName; + this.embeddingDimension = embeddingDimension; + } + + private TypesenseVectorStoreConfig(Builder builder) { + this.collectionName = builder.collectionName; + this.embeddingDimension = builder.embeddingDimension; + } + + /** + * {@return the default config} + */ + public static TypesenseVectorStoreConfig defaultConfig() { + return builder().build(); + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + public static class Builder { + + private String collectionName; + + private int embeddingDimension; + + /** + * Set the collection name. + * @param collectionName The collection name. + * @return The builder. + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * Set the embedding dimension. + * @param embeddingDimension The embedding dimension. + * @return The builder. + */ + public Builder withEmbeddingDimension(int embeddingDimension) { + this.embeddingDimension = embeddingDimension; + return this; + } + + /** + * Build the configuration. + * @return The configuration. + */ + public TypesenseVectorStoreConfig build() { + return new TypesenseVectorStoreConfig(this); + } + + } + + } + } diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java index ac27de2a8db..ea06769bc9e 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java index 044b0114aa5..23cd9f03bb6 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,7 +16,23 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.typesense.api.Client; +import org.typesense.api.Configuration; +import org.typesense.resources.Node; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -27,21 +43,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.typesense.api.Client; -import org.typesense.api.Configuration; -import org.typesense.resources.Node; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -79,7 +80,7 @@ public static String getText(String uri) { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")); @@ -127,10 +128,10 @@ void documentUpdate() { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Map info = ((TypesenseVectorStore) vectorStore).getCollectionInfo(); @@ -146,7 +147,7 @@ void addAndSearch() { @Test void searchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -201,11 +202,11 @@ void searchWithFilters() { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); @@ -221,7 +222,7 @@ void searchWithThreshold() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java index c5fdc881029..fefde3d3eb5 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +23,17 @@ import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.typesense.api.Client; +import org.typesense.api.Configuration; +import org.typesense.resources.Node; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -41,16 +50,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.typesense.api.Client; -import org.typesense.api.Configuration; -import org.typesense.resources.Node; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -87,13 +88,13 @@ public static String getText(String uri) { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-weaviate-store/pom.xml b/vector-stores/spring-ai-weaviate-store/pom.xml index c9359ca0889..c1ea5ff95be 100644 --- a/vector-stores/spring-ai-weaviate-store/pom.xml +++ b/vector-stores/spring-ai-weaviate-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java index 5dcb1caeb7d..08eb2f29f8d 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Date; @@ -37,11 +38,11 @@ */ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionConverter { - private boolean mapIntegerToNumberValue = true; - // https://weaviate.io/developers/weaviate/api/graphql/filters#special-cases private static final List SYSTEM_IDENTIFIERS = List.of("id", "_creationTimeUnix", "_lastUpdateTimeUnix"); + private boolean mapIntegerToNumberValue = true; + private List allowedIdentifierNames; public WeaviateFilterExpressionConverter(List allowedIdentifierNames) { @@ -189,4 +190,4 @@ protected void doGroup(Group group, StringBuilder context) { context); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 0c381b4e754..58a831eb10d 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -24,25 +24,8 @@ import java.util.Optional; import java.util.stream.Collectors; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; -import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; - import io.micrometer.observation.ObservationRegistry; import io.weaviate.client.WeaviateClient; import io.weaviate.client.base.Result; @@ -61,6 +44,23 @@ import io.weaviate.client.v1.graphql.query.fields.Field; import io.weaviate.client.v1.graphql.query.fields.Fields; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + /** * A VectorStore implementation backed by Weaviate vector database. * @@ -130,164 +130,6 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore { */ private final ObjectMapper objectMapper = new ObjectMapper(); - /** - * Configuration class for the WeaviateVectorStore. - */ - public static final class WeaviateVectorStoreConfig { - - public record MetadataField(String name, Type type) { - public enum Type { - - TEXT, NUMBER, BOOLEAN - - } - - public static MetadataField text(String name) { - return new MetadataField(name, Type.TEXT); - } - - public static MetadataField number(String name) { - return new MetadataField(name, Type.NUMBER); - } - - public static MetadataField bool(String name) { - return new MetadataField(name, Type.BOOLEAN); - } - - } - - /** - * https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-consistency-strategies - */ - public enum ConsistentLevel { - - /** - * Write must receive an acknowledgement from at least one replica node. This - * is the fastest (most available), but least consistent option. - */ - ONE, - - /** - * Write must receive an acknowledgement from at least QUORUM replica nodes. - * QUORUM is calculated as n / 2 + 1, where n is the number of replicas. - */ - QUORUM, - - /** - * Write must receive an acknowledgement from all replica nodes. This is the - * most consistent, but 'slowest'. - */ - ALL - - } - - private final String weaviateObjectClass; - - private final ConsistentLevel consistencyLevel; - - /** - * Known metadata fields to add as a fields to the Weaviate schema. You can add - * arbitrary metadata with your documents but only the metadata fields listed here - * can be used in the expression filters. - */ - private final List filterMetadataFields; - - private final Map headers; - - /** - * Constructor using the builder. - * @param builder The configuration builder. - */ - public WeaviateVectorStoreConfig(Builder builder) { - this.weaviateObjectClass = builder.objectClass; - this.consistencyLevel = builder.consistencyLevel; - this.filterMetadataFields = builder.filterMetadataFields; - this.headers = builder.headers; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * {@return the default config} - */ - public static WeaviateVectorStoreConfig defaultConfig() { - return builder().build(); - } - - public static class Builder { - - private String objectClass = "SpringAiWeaviate"; - - private ConsistentLevel consistencyLevel = WeaviateVectorStoreConfig.ConsistentLevel.ONE; - - private List filterMetadataFields = List.of(); - - private Map headers = Map.of(); - - private Builder() { - } - - /** - * Weaviate known, filterable metadata fields. - * @param filterMetadataFields known metadata fields to use. - * @return this builder. - */ - public Builder withFilterableMetadataFields(List filterMetadataFields) { - Assert.notNull(filterMetadataFields, "The filterMetadataFields can not be null."); - this.filterMetadataFields = filterMetadataFields; - return this; - } - - /** - * Weaviate config headers. - * @param headers config headers to use. - * @return this builder. - */ - public Builder withHeaders(Map headers) { - Assert.notNull(headers, "The headers can not be null."); - this.headers = headers; - return this; - } - - /** - * Weaviate objectClass. - * @param objectClass objectClass to use. - * @return this builder. - */ - public Builder withObjectClass(String objectClass) { - Assert.hasText(objectClass, "The objectClass can not be empty."); - this.objectClass = objectClass; - return this; - } - - /** - * Weaviate consistencyLevel. - * @param consistencyLevel consistencyLevel to use. - * @return this builder. - */ - public Builder withConsistencyLevel(ConsistentLevel consistencyLevel) { - Assert.notNull(consistencyLevel, "The consistencyLevel can not be null."); - this.consistencyLevel = consistencyLevel; - return this; - } - - /** - * {@return the immutable configuration} - */ - public WeaviateVectorStoreConfig build() { - return new WeaviateVectorStoreConfig(this); - } - - } - - } - /** * Constructs a new WeaviateVectorStore. * @param vectorStoreConfig The configuration for the store. @@ -462,7 +304,7 @@ public List doSimilaritySearch(SearchRequest request) { .build()) .limit(request.getTopK()) .withWhereFilter(WhereArgument.builder().build()) // adds an empty 'where:{}' - // placeholder. + // placeholder. .fields(Fields.builder().fields(this.weaviateSimilaritySearchFields).build()); String graphQLQuery = queryBuilder.build().buildQuery(); @@ -554,4 +396,163 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .withCollectionName(this.weaviateObjectClass); } -} \ No newline at end of file + /** + * Configuration class for the WeaviateVectorStore. + */ + public static final class WeaviateVectorStoreConfig { + + private final String weaviateObjectClass; + + private final ConsistentLevel consistencyLevel; + + /** + * Known metadata fields to add as a fields to the Weaviate schema. You can add + * arbitrary metadata with your documents but only the metadata fields listed here + * can be used in the expression filters. + */ + private final List filterMetadataFields; + + private final Map headers; + + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + public WeaviateVectorStoreConfig(Builder builder) { + this.weaviateObjectClass = builder.objectClass; + this.consistencyLevel = builder.consistencyLevel; + this.filterMetadataFields = builder.filterMetadataFields; + this.headers = builder.headers; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static WeaviateVectorStoreConfig defaultConfig() { + return builder().build(); + } + + /** + * https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-consistency-strategies + */ + public enum ConsistentLevel { + + /** + * Write must receive an acknowledgement from at least one replica node. This + * is the fastest (most available), but least consistent option. + */ + ONE, + + /** + * Write must receive an acknowledgement from at least QUORUM replica nodes. + * QUORUM is calculated as n / 2 + 1, where n is the number of replicas. + */ + QUORUM, + + /** + * Write must receive an acknowledgement from all replica nodes. This is the + * most consistent, but 'slowest'. + */ + ALL + + } + + public record MetadataField(String name, Type type) { + + public static MetadataField text(String name) { + return new MetadataField(name, Type.TEXT); + } + + public static MetadataField number(String name) { + return new MetadataField(name, Type.NUMBER); + } + + public static MetadataField bool(String name) { + return new MetadataField(name, Type.BOOLEAN); + } + + public enum Type { + + TEXT, NUMBER, BOOLEAN + + } + + } + + public static class Builder { + + private String objectClass = "SpringAiWeaviate"; + + private ConsistentLevel consistencyLevel = WeaviateVectorStoreConfig.ConsistentLevel.ONE; + + private List filterMetadataFields = List.of(); + + private Map headers = Map.of(); + + private Builder() { + } + + /** + * Weaviate known, filterable metadata fields. + * @param filterMetadataFields known metadata fields to use. + * @return this builder. + */ + public Builder withFilterableMetadataFields(List filterMetadataFields) { + Assert.notNull(filterMetadataFields, "The filterMetadataFields can not be null."); + this.filterMetadataFields = filterMetadataFields; + return this; + } + + /** + * Weaviate config headers. + * @param headers config headers to use. + * @return this builder. + */ + public Builder withHeaders(Map headers) { + Assert.notNull(headers, "The headers can not be null."); + this.headers = headers; + return this; + } + + /** + * Weaviate objectClass. + * @param objectClass objectClass to use. + * @return this builder. + */ + public Builder withObjectClass(String objectClass) { + Assert.hasText(objectClass, "The objectClass can not be empty."); + this.objectClass = objectClass; + return this; + } + + /** + * Weaviate consistencyLevel. + * @param consistencyLevel consistencyLevel to use. + * @return this builder. + */ + public Builder withConsistencyLevel(ConsistentLevel consistencyLevel) { + Assert.notNull(consistencyLevel, "The consistencyLevel can not be null."); + this.consistencyLevel = consistencyLevel; + return this; + } + + /** + * {@return the immutable configuration} + */ + public WeaviateVectorStoreConfig build() { + return new WeaviateVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java index fc004b3e165..53275f567e7 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java index 81e78bc5c70..3dbfcdd930c 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java index 6f57d7a5d81..b474cdaeded 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +23,14 @@ import java.util.Map; import java.util.UUID; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -35,13 +41,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.weaviate.WeaviateContainer; -import io.weaviate.client.Config; -import io.weaviate.client.WeaviateClient; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -78,32 +79,32 @@ public static String getText(String uri) { } private void resetCollection(VectorStore vectorStore) { - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); } @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); @@ -113,7 +114,7 @@ public void addAndSearch() { @Test public void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -167,7 +168,7 @@ public void searchWithFilters() throws InterruptedException { @Test public void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -210,13 +211,13 @@ public void documentUpdate() { @Test public void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); @@ -234,7 +235,7 @@ public void searchWithThreshold() { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); @@ -267,4 +268,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java index 35f9c4d599f..17b54c1881c 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import io.weaviate.client.WeaviateClient; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -38,15 +46,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.weaviate.WeaviateContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import io.weaviate.client.WeaviateClient; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,6 +60,9 @@ public class WeaviateVectorStoreObservationIT { static WeaviateContainer weaviateContainer = new WeaviateContainer(WeaviateImage.DEFAULT_IMAGE) .waitingFor(Wait.forHttp("/v1/.well-known/ready").forPort(8080)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -74,19 +78,16 @@ public static String getText(String uri) { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation()