diff --git a/pom.xml b/pom.xml index 73811933859..6863d729cf4 100644 --- a/pom.xml +++ b/pom.xml @@ -36,6 +36,7 @@ vector-stores/spring-ai-chroma vector-stores/spring-ai-azure vector-stores/spring-ai-weaviate + spring-ai-vertex-ai @@ -81,7 +82,7 @@ 17 - 3.1.3 + 3.2.0 4.0.2 0.16.0 1.0.0-beta.3 diff --git a/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/client/AzureOpenAiClientMetadataTests.java b/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/client/AzureOpenAiClientMetadataTests.java index b840793e96d..8163461c4fe 100644 --- a/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/client/AzureOpenAiClientMetadataTests.java +++ b/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/client/AzureOpenAiClientMetadataTests.java @@ -81,7 +81,7 @@ void azureOpenAiMetadataCapturedDuringGeneration() { Generation generation = response.getGeneration(); assertThat(generation).isNotNull() - .extracting(Generation::getText) + .extracting(Generation::getContent) .isEqualTo("No! You will actually land with a resounding thud. 
This is the way!"); assertPromptMetadata(response); diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml index 85ca743657f..bbcf76cb8df 100644 --- a/spring-ai-core/pom.xml +++ b/spring-ai-core/pom.xml @@ -39,6 +39,12 @@ + + org.springframework + spring-webflux + 6.1.1 + + org.springframework spring-messaging diff --git a/spring-ai-core/src/main/java/org/springframework/ai/client/AiClient.java b/spring-ai-core/src/main/java/org/springframework/ai/client/AiClient.java index 56af725afb4..d41a768e744 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/client/AiClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/client/AiClient.java @@ -24,7 +24,7 @@ public interface AiClient { default String generate(String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return generate(prompt).getGeneration().getText(); + return generate(prompt).getGeneration().getContent(); } AiResponse generate(Prompt prompt); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java index a5f7e600b7d..c8984ca72ac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java @@ -15,26 +15,42 @@ */ package org.springframework.ai.client; -import java.util.Collections; import java.util.List; -import java.util.Map; import org.springframework.ai.metadata.GenerationMetadata; import org.springframework.ai.metadata.PromptMetadata; import org.springframework.lang.Nullable; +/** + * The chat completion (e.g. generation) response returned by an AI provider. + */ public class AiResponse { private final GenerationMetadata metadata; + /** + * List of generated messages returned by the AI provider. + */ private final List generations; private PromptMetadata promptMetadata; + /** + * Construct a new {@link AiResponse} instance without metadata. 
+ * @param generations the {@link List} of {@link Generation} returned by the AI + * provider. + */ public AiResponse(List generations) { this(generations, GenerationMetadata.NULL); } + /** + * Construct a new {@link AiResponse} instance. + * @param generations the {@link List} of {@link Generation} returned by the AI + * provider. + * @param metadata {@link GenerationMetadata} containing information about the use of + * the AI provider's API. + */ public AiResponse(List generations, GenerationMetadata metadata) { this.metadata = metadata; this.generations = List.copyOf(generations); @@ -51,23 +67,22 @@ public List getGenerations() { return this.generations; } + /** + * @return Returns the first {@link Generation} in the generations list. + */ public Generation getGeneration() { return this.generations.get(0); } /** - * Returns {@link GenerationMetadata} containing information about the use of the AI - * provider's API. - * @return {@link GenerationMetadata} containing information about the use of the AI - * provider's API. + * @return Returns {@link GenerationMetadata} containing information about the use of + * the AI provider's API. */ public GenerationMetadata getGenerationMetadata() { return this.metadata; } /** - * Returns {@link PromptMetadata} containing information on prompt processing by the - * AI. * @return {@link PromptMetadata} containing information on prompt processing by the * AI. 
*/ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java index 06507b3845a..adcedfe52a3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java @@ -20,14 +20,14 @@ import java.util.Map; import org.springframework.ai.metadata.ChoiceMetadata; +import org.springframework.ai.prompt.messages.AbstractMessage; +import org.springframework.ai.prompt.messages.MessageType; import org.springframework.lang.Nullable; -public class Generation { - - // Just text for now - private final String text; - - private Map info; +/** + * Represents a response returned by the AI. + */ +public class Generation extends AbstractMessage { private ChoiceMetadata choiceMetadata; @@ -35,17 +35,12 @@ public Generation(String text) { this(text, Collections.emptyMap()); } - public Generation(String text, Map info) { - this.text = text; - this.info = Map.copyOf(info); - } - - public String getText() { - return this.text; + public Generation(String content, Map properties) { + super(MessageType.ASSISTANT, content, properties); } - public Map getInfo() { - return this.info; + public Generation(String content, Map properties, MessageType type) { + super(type, content, properties); } public ChoiceMetadata getChoiceMetadata() { @@ -60,7 +55,7 @@ public Generation withChoiceMetadata(@Nullable ChoiceMetadata choiceMetadata) { @Override public String toString() { - return "Generation{" + "text='" + text + '\'' + ", info=" + info + '}'; + return "Generation{" + "text='" + content + '\'' + ", info=" + properties + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java index 75234219b7c..68ec99df0c6 100644 --- 
a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java @@ -86,4 +86,40 @@ public String getMessageTypeValue() { return this.messageType.getValue(); } + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((content == null) ? 0 : content.hashCode()); + result = prime * result + ((properties == null) ? 0 : properties.hashCode()); + result = prime * result + ((messageType == null) ? 0 : messageType.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + AbstractMessage other = (AbstractMessage) obj; + if (content == null) { + if (other.content != null) + return false; + } + else if (!content.equals(other.content)) + return false; + if (properties == null) { + if (other.properties != null) + return false; + } + else if (!properties.equals(other.properties)) + return false; + if (messageType != other.messageType) + return false; + return true; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java index 6496c666be7..f5039b5f36a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java @@ -65,7 +65,7 @@ public List apply(List documents) { var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getContent())); - String keywords = this.aiClient.generate(prompt).getGeneration().getText(); + String keywords = 
this.aiClient.generate(prompt).getGeneration().getContent(); document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords)); } return documents; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java index 3ceb622445c..f43904a0bcd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java @@ -102,7 +102,7 @@ public List apply(List documents) { Prompt prompt = new PromptTemplate(this.summaryTemplate) .create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContext)); - documentSummaries.add(this.aiClient.generate(prompt).getGeneration().getText()); + documentSummaries.add(this.aiClient.generate(prompt).getGeneration().getContent()); } for (int i = 0; i < documentSummaries.size(); i++) { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/client/AiClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/client/AiClientTests.java index 2d5d2a99e7e..f673e1543a9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/client/AiClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/client/AiClientTests.java @@ -71,7 +71,7 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { verify(mockClient, times(1)).generate(eq(userMessage)); verify(mockClient, times(1)).generate(isA(Prompt.class)); verify(response, times(1)).getGeneration(); - verify(generation, times(1)).getText(); + verify(generation, times(1)).getContent(); verifyNoMoreInteractions(mockClient, generation, response); } diff --git a/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java index 
67b088b9ed6..59153fce7eb 100644 --- a/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ b/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java @@ -47,7 +47,7 @@ void helloWorldCompletion() { """; Prompt prompt = new Prompt(mistral7bInstruct); AiResponse aiResponse = huggingfaceAiClient.generate(prompt); - assertThat(aiResponse.getGeneration().getText()).isNotEmpty(); + assertThat(aiResponse.getGeneration().getContent()).isNotEmpty(); String expectedResponse = """ ```json { @@ -56,9 +56,9 @@ void helloWorldCompletion() { "address": "#1 Samuel St." } ```"""; - assertThat(aiResponse.getGeneration().getText()).isEqualTo(expectedResponse); - assertThat(aiResponse.getGeneration().getInfo()).containsKey("generated_tokens"); - assertThat(aiResponse.getGeneration().getInfo()).containsEntry("generated_tokens", 39); + assertThat(aiResponse.getGeneration().getContent()).isEqualTo(expectedResponse); + assertThat(aiResponse.getGeneration().getProperties()).containsKey("generated_tokens"); + assertThat(aiResponse.getGeneration().getProperties()).containsEntry("generated_tokens", 39); } diff --git a/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java b/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java index a3e668f43f6..4abb94230ba 100644 --- a/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java +++ b/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/client/OllamaClientTests.java @@ -21,7 +21,7 @@ public void smokeTest() { Assertions.assertNotNull(aiResponse); Assertions.assertFalse(CollectionUtils.isEmpty(aiResponse.getGenerations())); Assertions.assertNotNull(aiResponse.getGeneration()); - Assertions.assertNotNull(aiResponse.getGeneration().getText()); + Assertions.assertNotNull(aiResponse.getGeneration().getContent()); } private static OllamaClient getOllamaClient() 
{ diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientIT.java index a46f7bcfe4a..a7412ded937 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientIT.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientIT.java @@ -61,7 +61,7 @@ void outputParser() { Prompt prompt = new Prompt(promptTemplate.createMessage()); Generation generation = this.openAiClient.generate(prompt).getGeneration(); - List list = outputParser.parse(generation.getText()); + List list = outputParser.parse(generation.getContent()); assertThat(list).hasSize(5); } @@ -80,7 +80,7 @@ void mapOutputParser() { Prompt prompt = new Prompt(promptTemplate.createMessage()); Generation generation = openAiClient.generate(prompt).getGeneration(); - Map result = outputParser.parse(generation.getText()); + Map result = outputParser.parse(generation.getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -99,7 +99,7 @@ void beanOutputParser() { Prompt prompt = new Prompt(promptTemplate.createMessage()); Generation generation = openAiClient.generate(prompt).getGeneration(); - ActorsFilms actorsFilms = outputParser.parse(generation.getText()); + ActorsFilms actorsFilms = outputParser.parse(generation.getContent()); } record ActorsFilmsRecord(String actor, List movies) { @@ -119,7 +119,7 @@ void beanOutputParserRecords() { Prompt prompt = new Prompt(promptTemplate.createMessage()); Generation generation = openAiClient.generate(prompt).getGeneration(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getText()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } diff --git 
a/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index 8cec9f7e0b6..38e50a6c81d 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -39,7 +39,7 @@ public abstract class AbstractIT { protected void evaluateQuestionAndAnswer(String question, AiResponse response, boolean factBased) { assertThat(response).isNotNull(); - String answer = response.getGeneration().getText(); + String answer = response.getGeneration().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, @@ -53,12 +53,12 @@ protected void evaluateQuestionAndAnswer(String question, AiResponse response, b } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiClient.generate(prompt).getGeneration().getText(); + String yesOrNo = openAiClient.generate(prompt).getGeneration().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiClient.generate(prompt).getGeneration().getText(); + String reasonForFailure = openAiClient.generate(prompt).getGeneration().getContent(); fail(reasonForFailure); } else { diff --git a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java index 25aeaec352d..c65e2cce5e2 100644 --- 
a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java +++ b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java @@ -55,7 +55,7 @@ public class BasicEvaluationTest { protected void evaluateQuestionAndAnswer(String question, AiResponse response, boolean factBased) { assertThat(response).isNotNull(); - String answer = response.getGeneration().getText(); + String answer = response.getGeneration().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, @@ -69,12 +69,12 @@ protected void evaluateQuestionAndAnswer(String question, AiResponse response, b } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiClient.generate(prompt).getGeneration().getText(); + String yesOrNo = openAiClient.generate(prompt).getGeneration().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiClient.generate(prompt).getGeneration().getText(); + String reasonForFailure = openAiClient.generate(prompt).getGeneration().getContent(); fail(reasonForFailure); } else { diff --git a/spring-ai-vertex-ai/pom.xml b/spring-ai-vertex-ai/pom.xml new file mode 100644 index 00000000000..219e74d967c --- /dev/null +++ b/spring-ai-vertex-ai/pom.xml @@ -0,0 +1,71 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 0.8.0-SNAPSHOT + + spring-ai-vertex-ai + jar + Spring AI Vertex AI + Vertex AI support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + 
+ + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + com.google.cloud + google-cloud-aiplatform + 3.32.0 + + + commons-logging + commons-logging + + + + + + org.springframework + spring-web + 6.1.1 + + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java new file mode 100644 index 00000000000..c6035358941 --- /dev/null +++ b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java @@ -0,0 +1,614 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.vertex.api; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +// @formatter:off +/** + * Vertex AI API client for the Generative Language model. + * https://developers.generativeai.google/api/rest/generativelanguage + * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/streaming + * + * Provides methods to generate a response from the model given an input + * https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage + * + * as well as to generate embeddings for the input text: + * https://developers.generativeai.google/api/rest/generativelanguage/models/embedText + * + * + * Supported models: + * + *
+ * name=models/chat-bison-001,
+ * 		version=001,
+ * 		displayName=Chat Bison,
+ * 		description=Chat-optimized generative language model.,
+ * 		inputTokenLimit=4096,
+ * 		outputTokenLimit=1024,
+ * 		supportedGenerationMethods=[generateMessage, countMessageTokens],
+ * 		temperature=0.25,
+ * 		topP=0.95,
+ *		topK=40
+ *
+ * name=models/text-bison-001,
+ *		version=001,
+ *		displayName=Text Bison,
+ *		description=Model targeted for text generation.,
+ * 		inputTokenLimit=8196,
+ *		outputTokenLimit=1024,
+ *		supportedGenerationMethods=[generateText, countTextTokens, createTunedTextModel],
+ *		temperature=0.7,
+ *		topP=0.95,
+ *		topK=40
+ *
+ * name=models/embedding-gecko-001,
+ * 		version=001,
+ * 		displayName=Embedding Gecko, description=Obtain a distributed representation of a text.,
+ * 		inputTokenLimit=1024,
+ * 		outputTokenLimit=1,
+ * 		supportedGenerationMethods=[embedText, countTextTokens],
+ * 		temperature=null,
+ * 		topP=null,
+ * 		topK=null
+ * 
+ * + * @author Christian Tzolov + */ +public class VertexAiApi { + + /** + * The default generation model. This model is used to generate responses for the + * input text. + */ + public static final String DEFAULT_GENERATE_MODEL = "chat-bison-001"; + + /** + * The default embedding model. This model is used to generate embeddings for the + * input text. + */ + public static final String DEFAULT_EMBEDDING_MODEL = "embedding-gecko-001"; + + private static final String DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta3"; + + private final RestClient restClient; + + private final String apiKey; + + private final String generateModel; + + private final String embeddingModel; + + /** + * Create a new chat completion api. + * @param apiKey vertex apiKey. + */ + public VertexAiApi(String apiKey) { + this(DEFAULT_BASE_URL, apiKey, DEFAULT_GENERATE_MODEL, DEFAULT_EMBEDDING_MODEL, RestClient.builder()); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey vertex apiKey. + * @param model vertex model. + * @param embeddingModel vertex embedding model. + * @param restClientBuilder RestClient builder. 
+ */ + public VertexAiApi(String baseUrl, String apiKey, String model, String embeddingModel, + RestClient.Builder restClientBuilder) { + + this.generateModel = model; + this.embeddingModel = embeddingModel; + this.apiKey = apiKey; + + Consumer jsonContentHeaders = headers -> { + headers.setAccept(List.of(MediaType.APPLICATION_JSON)); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler() { + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(), + new ObjectMapper().readValue(response.getBody(), ResponseError.class))); + } + } + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + } + + /** + * Generates a response from the model given an input. + * @param request Request body. + * @return Response body. + */ + public GenerateMessageResponse generateMessage(GenerateMessageRequest request) { + Assert.notNull(request, "The request body can not be null."); + + return this.restClient.post() + .uri("/models/{model}:generateMessage?key={apiKey}", this.generateModel, this.apiKey) + .body(request) + .retrieve() + .body(GenerateMessageResponse.class); + } + + /** + * Generates an embedding for the given text. + * @param text Text to embed. + * @return Embedding response. 
+ */ + public Embedding embedText(String text) { + Assert.hasText(text, "The text can not be null or empty."); + + @JsonInclude(Include.NON_NULL) + record EmbeddingResponse(Embedding embedding) { + } + + return this.restClient.post() + .uri("/models/{model}:embedText?key={apiKey}", this.embeddingModel, this.apiKey) + .body(Map.of("text", text)) + .retrieve() + .body(EmbeddingResponse.class) + .embedding(); + } + + /** + * Generates embeddings for a batch of input texts. + * @param texts List of texts to embed. + * @return Embedding response containing a list of embeddings. + */ + public List batchEmbedText(List texts) { + Assert.notNull(texts, "The texts can not be null."); + + @JsonInclude(Include.NON_NULL) + record BatchEmbeddingResponse(List embeddings) { + } + + return this.restClient.post() + .uri("/models/{model}:batchEmbedText?key={apiKey}", this.embeddingModel, this.apiKey) + // https://developers.generativeai.google/api/rest/generativelanguage/models/batchEmbedText#request-body + .body(Map.of("texts", texts)) + .retrieve() + .body(BatchEmbeddingResponse.class) + .embeddings(); + } + + /** + * Returns the number of tokens in the message prompt. + * @param prompt Message prompt to count tokens for. + * @return Number of tokens in the message prompt. + */ + public Integer countMessageTokens(MessagePrompt prompt) { + + Assert.notNull(prompt, "The message prompt can not be null."); + + record TokenCount(@JsonProperty("tokenCount") Integer tokenCount) { + } + + return this.restClient.post() + .uri("/models/{model}:countMessageTokens?key={apiKey}", this.generateModel, this.apiKey) + .body(Map.of("prompt", prompt)) + .retrieve() + .body(TokenCount.class) + .tokenCount(); + } + + /** + * Returns the list of models available for use. + * @return List of models available for use. 
+ */ + public List listModels() { + + @JsonInclude(Include.NON_NULL) + record ModelList(@JsonProperty("models") List models) { + record ModelName(String name) { + } + } + + return this.restClient.get() + .uri("/models?key={apiKey}", this.apiKey) + .retrieve() + .body(ModelList.class) + .models() + .stream() + .map(ModelList.ModelName::name) + .toList(); + } + + /** + * Returns the model details. + * @param modelName Name of the model to get details for. + * @return Model details. + */ + public Model getModel(String modelName) { + + Assert.hasText(modelName, "The model name can not be null or empty."); + + if (modelName.startsWith("models/")) { + modelName = modelName.substring("models/".length()); + } + + return this.restClient.get() + .uri("/models/{model}?key={apiKey}", modelName, this.apiKey) + .retrieve() + .body(Model.class); + } + + /** + * API error response. + * + * @param error Error details. + */ + @JsonInclude(Include.NON_NULL) + public record ResponseError( + @JsonProperty("error") Error error) { + + /** + * Error details. + * + * @param message Error message. + * @param code Error code. + * @param status Error status. + */ + @JsonInclude(Include.NON_NULL) + public record Error( + @JsonProperty("message") String message, + @JsonProperty("code") String code, + @JsonProperty("status") String status) { + } + } + + /** + * Information about a Generative Language Model. + * + * @param name The resource name of the Model. Format: `models/{model} with a {model} + * naming convention of:` + * + *
+	 * {baseModelId}-{version}
+	 * 
+ * @param baseModelId The name of the base model, pass this to the generation request. + * @param version The version of the model. This represents the major version. + * @param displayName The human-readable name of the model. E.g. "Chat Bison". The + * name can be up to 128 characters long and can consist of any UTF-8 characters. + * @param description A short description of the model. + * @param inputTokenLimit Maximum number of input tokens allowed for this model. + * @param outputTokenLimit Maximum number of output tokens allowed for this model. + * @param supportedGenerationMethods List of supported generation methods for this + * model. The method names are defined as Pascal case strings, such as generateMessage + * which correspond to API methods. + * @param temperature Controls the randomness of the output. Values can range over + * [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more + * varied, while a value closer to 0.0 will typically result in less surprising + * responses from the model. This value specifies default to be used by the backend + * while making the call to the model. + * @param topP For Nucleus sampling. Nucleus sampling considers the smallest set of + * tokens whose probability sum is at least topP. This value specifies default to be + * used by the backend while making the call to the model. + * @param topK For Top-k sampling. Top-k sampling considers the set of topK most + * probable tokens. This value specifies default to be used by the backend while + * making the call to the model. 
+ */ + @JsonInclude(Include.NON_NULL) + public record Model( + @JsonProperty("name") String name, + @JsonProperty("baseModelId") String baseModelId, + @JsonProperty("version") String version, + @JsonProperty("displayName") String displayName, + @JsonProperty("description") String description, + @JsonProperty("inputTokenLimit") Integer inputTokenLimit, + @JsonProperty("outputTokenLimit") Integer outputTokenLimit, + @JsonProperty("supportedGenerationMethods") List supportedGenerationMethods, + @JsonProperty("temperature") Float temperature, + @JsonProperty("topP") Float topP, + @JsonProperty("topK") Integer topK) { + } + + /** + * A list of floats representing the embedding. + * + * @param value The embedding values. + */ + @JsonInclude(Include.NON_NULL) + public record Embedding( + @JsonProperty("value") List value) { + + } + + /** + * The base unit of structured text. A Message includes an author and the content of + * the Message. The author is used to tag messages when they are fed to the model as + * text. + * + * @param author (optional) Author of the message. This serves as a key for tagging + * the content of this Message when it is fed to the model as text. The author can be + * any alphanumeric string. + * @param content The text content of the structured Message. + * @param citationMetadata (output only) Citation information for model-generated + * content in this Message. If this Message was generated as output from the model, + * this field may be populated with attribution information for any text included in + * the content. This field is used only on output. + */ + @JsonInclude(Include.NON_NULL) + public record Message( + @JsonProperty("author") String author, + @JsonProperty("content") String content, + @JsonProperty("citationMetadata") CitationMetadata citationMetadata) { + + /** + * Short-hand constructor for a message without citation metadata. + * @param author (optional) Author of the message. 
+ * @param content The text content of the structured Message. + */ + public Message(String author, String content) { + this(author, content, null); + } + + /** + * A collection of source attributions for a piece of content. + * + * Citations to sources for a specific response. + */ + @JsonInclude(Include.NON_NULL) + public record CitationMetadata( + @JsonProperty("citationSources") List citationSources) { + } + + /** + * A citation to a source for a portion of a specific response. + * + * @param startIndex (optional) Start of segment of the response that is + * attributed to this source. Index indicates the start of the segment, measured + * in bytes. + * @param endIndex (optional) End of the attributed segment, exclusive. + * @param uri (optional) URI that is attributed as a source for a portion of the + * text. + * @param license (optional) License for the GitHub project that is attributed as + * a source for segment. License info is required for code citations. + */ + @JsonInclude(Include.NON_NULL) + public record CitationSource( + @JsonProperty("startIndex") Integer startIndex, + @JsonProperty("endIndex") Integer endIndex, + @JsonProperty("uri") String uri, + @JsonProperty("license") String license) { + } + } + + /** + * All of the structured input text passed to the model as a prompt. + * + * A MessagePrompt contains a structured set of fields that provide context for the + * conversation, examples of user input/model output message pairs that prime the + * model to respond in different ways, and the conversation history or list of + * messages representing the alternating turns of the conversation between the user + * and the model. + * + * @param context (optional) Text that should be provided to the model first to ground + * the response. If not empty, this context will be given to the model first before + * the examples and messages. When using a context be sure to provide it with every + * request to maintain continuity. 
This field can be a description of your prompt to + * the model to help provide context and guide the responses. Examples: "Translate the + * phrase from English to French." or "Given a statement, classify the sentiment as + * happy, sad or neutral." Anything included in this field will take precedence over + * message history if the total input size exceeds the model's inputTokenLimit and the + * input request is truncated. + * @param examples (optional) Examples of what the model should generate. This + * includes both user input and the response that the model should emulate. These + * examples are treated identically to conversation messages except that they take + * precedence over the history in messages: If the total input size exceeds the + * model's inputTokenLimit the input will be truncated. Items will be dropped from + * messages before examples. + * @param messages (optional) A snapshot of the recent conversation history sorted + * chronologically. Turns alternate between two authors. If the total input size + * exceeds the model's inputTokenLimit the input will be truncated: The oldest items + * will be dropped from messages. + */ + @JsonInclude(Include.NON_NULL) + public record MessagePrompt( + @JsonProperty("context") String context, + @JsonProperty("examples") List examples, + @JsonProperty("messages") List messages) { + + /** + * Shortcut constructor for a message prompt without context or examples. + * @param messages The conversation history used by the model. + */ + public MessagePrompt(List messages) { + this(null, null, messages); + } + + /** + * Shortcut constructor for a message prompt without examples. + * @param context Text that should be provided to the model first to ground + * the response. + * @param messages The conversation history used by the model.
+ */ + public MessagePrompt(String context, List messages) { + this(context, null, messages); + } + + /** + * An input/output example used to instruct the Model. It demonstrates how the + * model should respond or format its response. + * + * @param input An example of an input Message from the user. + * @param output An example of an output Message from the model. + */ + @JsonInclude(Include.NON_NULL) + public record Example( + @JsonProperty("input") Message input, + @JsonProperty("output") Message output) { + } + } + + /** + * Message generation request body. + * + * @param prompt The structured textual input given to the model as a prompt. Given a + * prompt, the model will return what it predicts is the next message in the + * discussion. + * @param temperature (optional) Controls the randomness of the output. Values can + * range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that + * are more varied, while a value closer to 0.0 will typically result in less + * surprising responses from the model. + * @param candidateCount (optional) The number of generated response messages to + * return. This value must be between [1, 8], inclusive. If unset, this will default + * to 1. + * @param topP (optional) The maximum cumulative probability of tokens to consider + * when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling + * considers the smallest set of tokens whose probability sum is at least topP. + * @param topK (optional) The maximum number of tokens to consider when sampling. The + * model uses combined Top-k and nucleus sampling. Top-k sampling considers the set of + * topK most probable tokens. 
+ */ + @JsonInclude(Include.NON_NULL) + public record GenerateMessageRequest( + @JsonProperty("prompt") MessagePrompt prompt, + @JsonProperty("temperature") Float temperature, + @JsonProperty("candidateCount") Integer candidateCount, + @JsonProperty("topP") Float topP, + @JsonProperty("topK") Integer topK) { + + /** + * Shortcut constructor to create a GenerateMessageRequest with only the prompt + * parameter. + * @param prompt The structured textual input given to the model as a prompt. + */ + public GenerateMessageRequest(MessagePrompt prompt) { + this(prompt, null, null, null, null); + } + + /** + * Shortcut constructor to create a GenerateMessageRequest with only the prompt, + * temperature and topK parameters. + * @param prompt The structured textual input given to the model as a prompt. + * @param temperature (optional) Controls the randomness of the output. + * @param topK (optional) The maximum number of tokens to consider when sampling. + */ + public GenerateMessageRequest(MessagePrompt prompt, Float temperature, Integer topK) { + this(prompt, temperature, null, null, topK); + } + } + + /** + * The response from the model. This includes candidate messages and conversation + * history in the form of chronologically-ordered messages. + * + * @param candidates Candidate response messages from the model. + * @param messages The conversation history used by the model. + * @param filters A set of content filtering metadata for the prompt and response + * text. This indicates which SafetyCategory(s) blocked a candidate from this + * response, the lowest HarmProbability that triggered a block, and the HarmThreshold + * setting for that category. + */ + @JsonInclude(Include.NON_NULL) + public record GenerateMessageResponse( + @JsonProperty("candidates") List candidates, + @JsonProperty("messages") List messages, + @JsonProperty("filters") List filters) { + + /** + * Content filtering metadata associated with processing a single request.
It + * contains a reason and an optional supporting string. The reason may be + * unspecified. + * + * @param reason The reason content was blocked during request processing. + * @param message A string that describes the filtering behavior in more detail. + */ + @JsonInclude(Include.NON_NULL) + public record ContentFilter( + @JsonProperty("reason") BlockedReason reason, + @JsonProperty("message") String message) { + + /** + * Reasons why content may have been blocked. + */ + public enum BlockedReason { + + /** + * A blocked reason was not specified. + */ + BLOCKED_REASON_UNSPECIFIED, + /** + * Content was blocked by safety settings. + */ + SAFETY, + /** + * Content was blocked, but the reason is uncategorized. + */ + OTHER + + } + } + } + + /** + * Main method to test the VertexAiApi. + * @param args blank. + */ + public static void main(String[] args) { + VertexAiApi vertexAiApi = new VertexAiApi(System.getenv("PALM_API_KEY")); + + var prompt = new MessagePrompt(List.of(new Message("0", "Hello, how are you?"))); + + GenerateMessageRequest request = new GenerateMessageRequest(prompt); + + GenerateMessageResponse response = vertexAiApi.generateMessage(request); + + System.out.println(response); + + System.out.println(vertexAiApi.embedText("Hello, how are you?")); + + System.out.println(vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!"))); + + System.out.println(vertexAiApi.countMessageTokens(prompt)); + + System.out.println(vertexAiApi.listModels()); + + System.out.println(vertexAiApi.listModels().stream().map(vertexAiApi::getModel).toList()); + + } + +} +// @formatter:on \ No newline at end of file diff --git a/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClient.java b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClient.java new file mode 100644 index 00000000000..bb18bdf7e19 --- /dev/null +++ 
b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClient.java @@ -0,0 +1,67 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertex.embedding; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.vertex.api.VertexAiApi; + +/** + * @author Christian Tzolov + */ +public class VertexAiEmbeddingClient implements EmbeddingClient { + + private final VertexAiApi vertexAiApi; + + public VertexAiEmbeddingClient(VertexAiApi vertexAiApi) { + this.vertexAiApi = vertexAiApi; + } + + @Override + public List embed(String text) { + return this.vertexAiApi.embedText(text).value(); + } + + @Override + public List embed(Document document) { + return embed(document.getContent()); + } + + @Override + public List> embed(List texts) { + List vertexEmbeddings = this.vertexAiApi.batchEmbedText(texts); + return vertexEmbeddings.stream().map(e -> e.value()).toList(); + } + + @Override + public EmbeddingResponse embedForResponse(List texts) { + List vertexEmbeddings = this.vertexAiApi.batchEmbedText(texts); + int index = 0; + List embeddings = new ArrayList<>(); + for 
(VertexAiApi.Embedding vertexEmbedding : vertexEmbeddings) { + embeddings.add(new Embedding(vertexEmbedding.value(), index++)); + } + return new EmbeddingResponse(embeddings, Map.of()); + } + +} diff --git a/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClient.java b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClient.java new file mode 100644 index 00000000000..b51fe674da6 --- /dev/null +++ b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClient.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.vertex.generation; + +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.ai.client.AiClient; +import org.springframework.ai.client.AiResponse; +import org.springframework.ai.client.Generation; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.vertex.api.VertexAiApi; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse; +import org.springframework.ai.vertex.api.VertexAiApi.MessagePrompt; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * @author Christian Tzolov + */ +public class VertexAiChatGenerationClient implements AiClient { + + private final VertexAiApi vertexAiApi; + + private Float temperature; + + private Float topP; + + private Integer candidateCount; + + private Integer maxTokens; + + public VertexAiChatGenerationClient(VertexAiApi vertexAiApi) { + this.vertexAiApi = vertexAiApi; + } + + @Override + public AiResponse generate(Prompt prompt) { + + String vertexContext = prompt.getMessages() + .stream() + .filter(m -> m.getMessageType() == MessageType.SYSTEM) + .map(m -> m.getContent()) + .collect(Collectors.joining("\n")); + + List vertexMessages = prompt.getMessages() + .stream() + .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) + .map(m -> new VertexAiApi.Message(m.getMessageType().getValue(), m.getContent())) + .toList(); + + Assert.isTrue(!CollectionUtils.isEmpty(vertexMessages), "No user or assistant messages found in the prompt!"); + + var vertexPrompt = new MessagePrompt(vertexContext, vertexMessages); + + GenerateMessageRequest request = new GenerateMessageRequest(vertexPrompt, this.temperature, this.candidateCount, + this.topP, this.maxTokens); + + GenerateMessageResponse response = 
this.vertexAiApi.generateMessage(request); + + List generations = response.candidates() + .stream() + .map(vmsg -> new Generation(vmsg.content())) + .collect(Collectors.toList()); + + return new AiResponse(generations); + } + +} diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiIT.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiIT.java new file mode 100644 index 00000000000..533d810a417 --- /dev/null +++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiIT.java @@ -0,0 +1,125 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertex.api; + +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.vertex.api.VertexAiApi.Embedding; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse; +import org.springframework.ai.vertex.api.VertexAiApi.MessagePrompt; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link VertexAiApi}. 
Requires a valid API key to be set via the + * {@code PALM_API_KEY} environment variable; at the moment Google enables it only in the US + * region (so use a VPN for testing). + * + * @author Christian Tzolov + */ +@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*") +public class VertexAiApiIT { + + VertexAiApi vertexAiApi = new VertexAiApi(System.getenv("PALM_API_KEY")); + + @Test + public void generateMessage() { + + var prompt = new MessagePrompt(List.of(new VertexAiApi.Message("0", "Hello, how are you?"))); + + GenerateMessageRequest request = new GenerateMessageRequest(prompt); + + GenerateMessageResponse response = vertexAiApi.generateMessage(request); + + assertThat(response).isNotNull(); + + // Vertex returns the prompt messages in the response's messages list. + assertThat(response.messages()).hasSize(1); + assertThat(response.messages().get(0)).isEqualTo(prompt.messages().get(0)); + + // Vertex returns the answer in the response's candidates list. + assertThat(response.candidates()).hasSize(1); + assertThat(response.candidates().get(0).author()).isNotBlank(); + assertThat(response.candidates().get(0).content()).isNotBlank(); + } + + @Test + public void embedText() { + + var text = "Hello, how are you?"; + + Embedding response = vertexAiApi.embedText(text); + + assertThat(response).isNotNull(); + assertThat(response.value()).hasSize(768); + } + + @Test + public void batchEmbedText() { + + var text = List.of("Hello, how are you?", "I am fine, thank you!"); + + List response = vertexAiApi.batchEmbedText(text); + + assertThat(response).isNotNull(); + assertThat(response).hasSize(2); + assertThat(response.get(0).value()).hasSize(768); + assertThat(response.get(1).value()).hasSize(768); + } + + @Test + public void countMessageTokens() { + + var text = "Hello, how are you?"; + + var prompt = new MessagePrompt(List.of(new VertexAiApi.Message("0", text))); + int response = vertexAiApi.countMessageTokens(prompt); + + assertThat(response).isEqualTo(17);
+ } + + @Test + public void listModels() { + + List response = vertexAiApi.listModels(); + + assertThat(response).isNotNull(); + assertThat(response).hasSizeGreaterThan(0); + assertThat(response).contains("models/chat-bison-001", "models/text-bison-001", "models/embedding-gecko-001"); + + System.out.println(" - " + response.stream() + .map(vertexAiApi::getModel) + .map(VertexAiApi.Model::toString) + .collect(Collectors.joining("\n - "))); + } + + @Test + public void getModel() { + + VertexAiApi.Model model = vertexAiApi.getModel("models/chat-bison-001"); + + System.out.println(model); + assertThat(model).isNotNull(); + assertThat(model.displayName()).isEqualTo("Chat Bison"); + } + +} diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java new file mode 100644 index 00000000000..78645d27a2a --- /dev/null +++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.vertex.api; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vertex.api.VertexAiApi.Embedding; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse; +import org.springframework.ai.vertex.api.VertexAiApi.MessagePrompt; +import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse.ContentFilter.BlockedReason; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.content; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestToUriTemplate; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + +/** + * @author Christian Tzolov + */ +@RestClientTest(VertexAiApiTests.Config.class) +public class VertexAiApiTests { + + private final static String TEST_API_KEY = "test-api-key"; + + @Autowired + private VertexAiApi client; + + @Autowired + private MockRestServiceServer server; + + @Autowired + private ObjectMapper objectMapper; + + @AfterEach + void resetMockServer() { + server.reset(); + } + + 
@Test + public void generateMessage() throws JsonProcessingException { + + GenerateMessageRequest request = new GenerateMessageRequest( + new MessagePrompt(List.of(new VertexAiApi.Message("0", "Hello, how are you?")))); + + GenerateMessageResponse expectedResponse = new GenerateMessageResponse( + List.of(new VertexAiApi.Message("1", "Hello, how are you?")), + List.of(new VertexAiApi.Message("0", "I'm fine, thank you.")), + List.of(new VertexAiApi.GenerateMessageResponse.ContentFilter(BlockedReason.SAFETY, "reason"))); + + server + .expect(requestToUriTemplate("/models/{model}:generateMessage?key={apiKey}", + VertexAiApi.DEFAULT_GENERATE_MODEL, TEST_API_KEY)) + .andExpect(method(HttpMethod.POST)) + .andExpect(content().json(objectMapper.writeValueAsString(request))) + .andRespond(withSuccess(objectMapper.writeValueAsString(expectedResponse), MediaType.APPLICATION_JSON)); + + GenerateMessageResponse response = client.generateMessage(request); + + assertThat(response).isEqualTo(expectedResponse); + + server.verify(); + } + + @Test + public void embedText() throws JsonProcessingException { + + String text = "Hello, how are you?"; + + Embedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3)); + + server + .expect(requestToUriTemplate("/models/{model}:embedText?key={apiKey}", VertexAiApi.DEFAULT_EMBEDDING_MODEL, + TEST_API_KEY)) + .andExpect(method(HttpMethod.POST)) + .andExpect(content().json(objectMapper.writeValueAsString(Map.of("text", text)))) + .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), + MediaType.APPLICATION_JSON)); + + Embedding embedding = client.embedText(text); + + assertThat(embedding).isEqualTo(expectedEmbedding); + + server.verify(); + } + + @Test + public void batchEmbedText() throws JsonProcessingException { + + List texts = List.of("Hello, how are you?", "I'm fine, thank you."); + + List expectedEmbeddings = List.of(new Embedding(List.of(0.1, 0.2, 0.3)), + new Embedding(List.of(0.4, 0.5, 
0.6))); + + server + .expect(requestToUriTemplate("/models/{model}:batchEmbedText?key={apiKey}", + VertexAiApi.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) + .andExpect(method(HttpMethod.POST)) + .andExpect(content().json(objectMapper.writeValueAsString(Map.of("texts", texts)))) + .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embeddings", expectedEmbeddings)), + MediaType.APPLICATION_JSON)); + + List embeddings = client.batchEmbedText(texts); + + assertThat(embeddings).isEqualTo(expectedEmbeddings); + + server.verify(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public VertexAiApi audioApi(RestClient.Builder builder) { + return new VertexAiApi("", TEST_API_KEY, VertexAiApi.DEFAULT_GENERATE_MODEL, + VertexAiApi.DEFAULT_EMBEDDING_MODEL, builder); + } + + } + +} diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java new file mode 100644 index 00000000000..78ebbfc308a --- /dev/null +++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java @@ -0,0 +1,48 @@ +package org.springframework.ai.vertex.embedding; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.vertex.api.VertexAiApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*") +class VertexAiEmbeddingClientIT { + + @Autowired + private 
VertexAiEmbeddingClient embeddingClient; + + @Test + void simpleEmbedding() { + assertThat(embeddingClient).isNotNull(); + EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); + assertThat(embeddingResponse.getData()).hasSize(1); + assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingClient.dimensions()).isEqualTo(768); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public VertexAiApi vertexAiApi() { + return new VertexAiApi(System.getenv("PALM_API_KEY")); + } + + @Bean + public VertexAiEmbeddingClient vertexAiEmbedding(VertexAiApi vertexAiApi) { + return new VertexAiEmbeddingClient(vertexAiApi); + } + + } + +} diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java new file mode 100644 index 00000000000..9c5a610067d --- /dev/null +++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java @@ -0,0 +1,130 @@ +package org.springframework.ai.vertex.generation; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.client.AiResponse; +import org.springframework.ai.client.Generation; +import org.springframework.ai.parser.BeanOutputParser; +import org.springframework.ai.parser.ListOutputParser; +import org.springframework.ai.parser.MapOutputParser; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.prompt.SystemPromptTemplate; +import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.prompt.messages.UserMessage; +import 
org.springframework.ai.vertex.api.VertexAiApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*") +class VertexAiChatGenerationClientIT { + + @Autowired + private VertexAiChatGenerationClient client; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + AiResponse response = client.generate(prompt); + assertThat(response.getGeneration().getContent()).contains("Bartholomew"); + } + + // @Test + void outputParser() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputParser outputParser = new ListOutputParser(conversionService); + + String format = outputParser.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors.", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.client.generate(prompt).getGeneration(); + + List list = 
outputParser.parse(generation.getContent()); + assertThat(list).hasSize(5); + + } + + // @Test + void mapOutputParser() { + MapOutputParser outputParser = new MapOutputParser(); + + String format = outputParser.getFormat(); + String template = """ + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = client.generate(prompt).getGeneration(); + + Map result = outputParser.parse(generation.getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + // @Test + void beanOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. 
+ {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = client.generate(prompt).getGeneration(); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public VertexAiApi vertexAiApi() { + return new VertexAiApi(System.getenv("PALM_API_KEY")); + } + + @Bean + public VertexAiChatGenerationClient vertexAiEmbedding(VertexAiApi vertexAiApi) { + return new VertexAiChatGenerationClient(vertexAiApi); + } + + } + +} diff --git a/spring-ai-vertex-ai/src/test/resources/prompts/system-message.st b/spring-ai-vertex-ai/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..dc2cf2dcd84 --- /dev/null +++ b/spring-ai-vertex-ai/src/test/resources/prompts/system-message.st @@ -0,0 +1,4 @@ +"You are a helpful AI assistant. Your name is {name}. +You are an AI assistant that helps people find information. +Your name is {name} +You should reply to the user's request with your name and also in the style of a {voice}. \ No newline at end of file