movies) {
@@ -119,7 +119,7 @@ void beanOutputParserRecords() {
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = openAiClient.generate(prompt).getGeneration();
- ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getText());
+ ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
}
diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java
index 8cec9f7e0b6..38e50a6c81d 100644
--- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java
+++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java
@@ -39,7 +39,7 @@ public abstract class AbstractIT {
protected void evaluateQuestionAndAnswer(String question, AiResponse response, boolean factBased) {
assertThat(response).isNotNull();
- String answer = response.getGeneration().getText();
+ String answer = response.getGeneration().getContent();
logger.info("Question: " + question);
logger.info("Answer:" + answer);
PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource,
@@ -53,12 +53,12 @@ protected void evaluateQuestionAndAnswer(String question, AiResponse response, b
}
Message userMessage = userPromptTemplate.createMessage();
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
- String yesOrNo = openAiClient.generate(prompt).getGeneration().getText();
+ String yesOrNo = openAiClient.generate(prompt).getGeneration().getContent();
logger.info("Is Answer related to question: " + yesOrNo);
if (yesOrNo.equalsIgnoreCase("no")) {
SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource);
prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage));
- String reasonForFailure = openAiClient.generate(prompt).getGeneration().getText();
+ String reasonForFailure = openAiClient.generate(prompt).getGeneration().getContent();
fail(reasonForFailure);
}
else {
diff --git a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java
index 25aeaec352d..c65e2cce5e2 100644
--- a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java
+++ b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java
@@ -55,7 +55,7 @@ public class BasicEvaluationTest {
protected void evaluateQuestionAndAnswer(String question, AiResponse response, boolean factBased) {
assertThat(response).isNotNull();
- String answer = response.getGeneration().getText();
+ String answer = response.getGeneration().getContent();
logger.info("Question: " + question);
logger.info("Answer:" + answer);
PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource,
@@ -69,12 +69,12 @@ protected void evaluateQuestionAndAnswer(String question, AiResponse response, b
}
Message userMessage = userPromptTemplate.createMessage();
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
- String yesOrNo = openAiClient.generate(prompt).getGeneration().getText();
+ String yesOrNo = openAiClient.generate(prompt).getGeneration().getContent();
logger.info("Is Answer related to question: " + yesOrNo);
if (yesOrNo.equalsIgnoreCase("no")) {
SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource);
prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage));
- String reasonForFailure = openAiClient.generate(prompt).getGeneration().getText();
+ String reasonForFailure = openAiClient.generate(prompt).getGeneration().getContent();
fail(reasonForFailure);
}
else {
diff --git a/spring-ai-vertex-ai/pom.xml b/spring-ai-vertex-ai/pom.xml
new file mode 100644
index 00000000000..219e74d967c
--- /dev/null
+++ b/spring-ai-vertex-ai/pom.xml
@@ -0,0 +1,71 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai
+ 0.8.0-SNAPSHOT
+
+ spring-ai-vertex-ai
+ jar
+ Spring AI Vertex AI
+ Vertex AI support
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+
+
+
+ org.springframework.ai
+ spring-ai-core
+ ${project.parent.version}
+
+
+
+ com.google.cloud
+ google-cloud-aiplatform
+ 3.32.0
+
+
+ commons-logging
+ commons-logging
+
+
+
+
+
+ org.springframework
+ spring-web
+ 6.1.1
+
+
+
+
+
+ org.springframework
+ spring-context-support
+
+
+
+ org.springframework.boot
+ spring-boot-starter-logging
+
+
+
+
+ org.springframework.ai
+ spring-ai-test
+ ${project.version}
+ test
+
+
+
+
+
diff --git a/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java
new file mode 100644
index 00000000000..c6035358941
--- /dev/null
+++ b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java
@@ -0,0 +1,614 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.vertex.api;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.util.Assert;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+
+// @formatter:off
+/**
+ * Vertex AI API client for the Generative Language model.
+ * https://developers.generativeai.google/api/rest/generativelanguage
+ * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/streaming
+ *
+ * Provides methods to generate a response from the model given an input
+ * https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
+ *
+ * as well as to generate embeddings for the input text:
+ * https://developers.generativeai.google/api/rest/generativelanguage/models/embedText
+ *
+ *
+ * Supported models:
+ *
+ *
+ * name=models/chat-bison-001,
+ * version=001,
+ * displayName=Chat Bison,
+ * description=Chat-optimized generative language model.,
+ * inputTokenLimit=4096,
+ * outputTokenLimit=1024,
+ * supportedGenerationMethods=[generateMessage, countMessageTokens],
+ * temperature=0.25,
+ * topP=0.95,
+ * topK=40
+ *
+ * name=models/text-bison-001,
+ * version=001,
+ * displayName=Text Bison,
+ * description=Model targeted for text generation.,
+ * inputTokenLimit=8196,
+ * outputTokenLimit=1024,
+ * supportedGenerationMethods=[generateText, countTextTokens, createTunedTextModel],
+ * temperature=0.7,
+ * topP=0.95,
+ * topK=40
+ *
+ * name=models/embedding-gecko-001,
+ * version=001,
+ * displayName=Embedding Gecko, description=Obtain a distributed representation of a text.,
+ * inputTokenLimit=1024,
+ * outputTokenLimit=1,
+ * supportedGenerationMethods=[embedText, countTextTokens],
+ * temperature=null,
+ * topP=null,
+ * topK=null
+ *
+ *
+ * @author Christian Tzolov
+ */
+public class VertexAiApi {
+
+ /**
+ * The default generation model. This model is used to generate responses for the
+ * input text.
+ */
+ public static final String DEFAULT_GENERATE_MODEL = "chat-bison-001";
+
+ /**
+ * The default embedding model. This model is used to generate embeddings for the
+ * input text.
+ */
+ public static final String DEFAULT_EMBEDDING_MODEL = "embedding-gecko-001";
+
+ private static final String DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta3";
+
+ private final RestClient restClient;
+
+ private final String apiKey;
+
+ private final String generateModel;
+
+ private final String embeddingModel;
+
+ /**
+ * Create an new chat completion api.
+ * @param apiKey vertex apiKey.
+ */
+ public VertexAiApi(String apiKey) {
+ this(DEFAULT_BASE_URL, apiKey, DEFAULT_GENERATE_MODEL, DEFAULT_EMBEDDING_MODEL, RestClient.builder());
+ }
+
+ /**
+ * Create an new chat completion api.
+ * @param baseUrl api base URL.
+ * @param apiKey vertex apiKey.
+ * @param model vertex model.
+ * @param embeddingModel vertex embedding model.
+ * @param restClientBuilder RestClient builder.
+ */
+ public VertexAiApi(String baseUrl, String apiKey, String model, String embeddingModel,
+ RestClient.Builder restClientBuilder) {
+
+ this.generateModel = model;
+ this.embeddingModel = embeddingModel;
+ this.apiKey = apiKey;
+
+ Consumer<HttpHeaders> jsonContentHeaders = headers -> {
+ headers.setAccept(List.of(MediaType.APPLICATION_JSON));
+ headers.setContentType(MediaType.APPLICATION_JSON);
+ };
+
+ ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler() {
+ @Override
+ public boolean hasError(ClientHttpResponse response) throws IOException {
+ return response.getStatusCode().isError();
+ }
+
+ @Override
+ public void handleError(ClientHttpResponse response) throws IOException {
+ if (response.getStatusCode().isError()) {
+ throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(),
+ new ObjectMapper().readValue(response.getBody(), ResponseError.class)));
+ }
+ }
+ };
+
+ this.restClient = restClientBuilder.baseUrl(baseUrl)
+ .defaultHeaders(jsonContentHeaders)
+ .defaultStatusHandler(responseErrorHandler)
+ .build();
+ }
+
+ /**
+ * Generates a response from the model given an input.
+ * @param request Request body.
+ * @return Response body.
+ */
+ public GenerateMessageResponse generateMessage(GenerateMessageRequest request) {
+ Assert.notNull(request, "The request body can not be null.");
+
+ return this.restClient.post()
+ .uri("/models/{model}:generateMessage?key={apiKey}", this.generateModel, this.apiKey)
+ .body(request)
+ .retrieve()
+ .body(GenerateMessageResponse.class);
+ }
+
+ /**
+ * Generates a response from the model given an input.
+ * @param text Text to embed.
+ * @return Embedding response.
+ */
+ public Embedding embedText(String text) {
+ Assert.hasText(text, "The text can not be null or empty.");
+
+ @JsonInclude(Include.NON_NULL)
+ record EmbeddingResponse(Embedding embedding) {
+ }
+
+ return this.restClient.post()
+ .uri("/models/{model}:embedText?key={apiKey}", this.embeddingModel, this.apiKey)
+ .body(Map.of("text", text))
+ .retrieve()
+ .body(EmbeddingResponse.class)
+ .embedding();
+ }
+
+ /**
+ * Generates a response from the model given an input.
+ * @param texts List of texts to embed.
+ * @return Embedding response containing a list of embeddings.
+ */
+ public List<Embedding> batchEmbedText(List<String> texts) {
+ Assert.notNull(texts, "The texts can not be null.");
+
+ @JsonInclude(Include.NON_NULL)
+ record BatchEmbeddingResponse(List<Embedding> embeddings) {
+ }
+
+ return this.restClient.post()
+ .uri("/models/{model}:batchEmbedText?key={apiKey}", this.embeddingModel, this.apiKey)
+ // https://developers.generativeai.google/api/rest/generativelanguage/models/batchEmbedText#request-body
+ .body(Map.of("texts", texts))
+ .retrieve()
+ .body(BatchEmbeddingResponse.class)
+ .embeddings();
+ }
+
+ /**
+ * Returns the number of tokens in the message prompt.
+ * @param prompt Message prompt to count tokens for.
+ * @return Number of tokens in the message prompt.
+ */
+ public Integer countMessageTokens(MessagePrompt prompt) {
+
+ Assert.notNull(prompt, "The message prompt can not be null.");
+
+ record TokenCount(@JsonProperty("tokenCount") Integer tokenCount) {
+ }
+
+ return this.restClient.post()
+ .uri("/models/{model}:countMessageTokens?key={apiKey}", this.generateModel, this.apiKey)
+ .body(Map.of("prompt", prompt))
+ .retrieve()
+ .body(TokenCount.class)
+ .tokenCount();
+ }
+
+ /**
+ * Returns the list of models available for use.
+ * @return List of models available for use.
+ */
+ public List<String> listModels() {
+
+ @JsonInclude(Include.NON_NULL)
+ record ModelList(@JsonProperty("models") List<ModelName> models) {
+ record ModelName(String name) {
+ }
+ }
+
+ return this.restClient.get()
+ .uri("/models?key={apiKey}", this.apiKey)
+ .retrieve()
+ .body(ModelList.class)
+ .models()
+ .stream()
+ .map(ModelList.ModelName::name)
+ .toList();
+ }
+
+ /**
+ * Returns the model details.
+ * @param modelName Name of the model to get details for.
+ * @return Model details.
+ */
+ public Model getModel(String modelName) {
+
+ Assert.hasText(modelName, "The model name can not be null or empty.");
+
+ if (modelName.startsWith("models/")) {
+ modelName = modelName.substring("models/".length());
+ }
+
+ return this.restClient.get()
+ .uri("/models/{model}?key={apiKey}", modelName, this.apiKey)
+ .retrieve()
+ .body(Model.class);
+ }
+
+ /**
+ * API error response.
+ *
+ * @param error Error details.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record ResponseError(
+ @JsonProperty("error") Error error) {
+
+ /**
+ * Error details.
+ *
+ * @param message Error message.
+ * @param code Error code.
+ * @param status Error status.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Error(
+ @JsonProperty("message") String message,
+ @JsonProperty("code") String code,
+ @JsonProperty("status") String status) {
+ }
+ }
+
+ /**
+ * Information about a Generative Language Model.
+ *
+ * @param name The resource name of the Model. Format: `models/{model} with a {model}
+ * naming convention of:`
+ *
+ *
+ * {baseModelId}-{version}
+ *
+ * @param baseModelId The name of the base model, pass this to the generation request.
+ * @param version The version of the model. This represents the major version.
+ * @param displayName The human-readable name of the model. E.g. "Chat Bison". The
+ * name can be up to 128 characters long and can consist of any UTF-8 characters.
+ * @param description A short description of the model.
+ * @param inputTokenLimit Maximum number of input tokens allowed for this model.
+ * @param outputTokenLimit Maximum number of output tokens allowed for this model.
+ * @param supportedGenerationMethods List of supported generation methods for this
+ * model. The method names are defined as Pascal case strings, such as generateMessage
+ * which correspond to API methods.
+ * @param temperature Controls the randomness of the output. Values can range over
+ * [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more
+ * varied, while a value closer to 0.0 will typically result in less surprising
+ * responses from the model. This value specifies default to be used by the backend
+ * while making the call to the model.
+ * @param topP For Nucleus sampling. Nucleus sampling considers the smallest set of
+ * tokens whose probability sum is at least topP. This value specifies default to be
+ * used by the backend while making the call to the model.
+ * @param topK For Top-k sampling. Top-k sampling considers the set of topK most
+ * probable tokens. This value specifies default to be used by the backend while
+ * making the call to the model.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Model(
+ @JsonProperty("name") String name,
+ @JsonProperty("baseModelId") String baseModelId,
+ @JsonProperty("version") String version,
+ @JsonProperty("displayName") String displayName,
+ @JsonProperty("description") String description,
+ @JsonProperty("inputTokenLimit") Integer inputTokenLimit,
+ @JsonProperty("outputTokenLimit") Integer outputTokenLimit,
+ @JsonProperty("supportedGenerationMethods") List<String> supportedGenerationMethods,
+ @JsonProperty("temperature") Float temperature,
+ @JsonProperty("topP") Float topP,
+ @JsonProperty("topK") Integer topK) {
+ }
+
+ /**
+ * A list of floats representing the embedding.
+ *
+ * @param value The embedding values.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Embedding(
+ @JsonProperty("value") List<Double> value) {
+
+ }
+
+ /**
+ * The base unit of structured text. A Message includes an author and the content of
+ * the Message. The author is used to tag messages when they are fed to the model as
+ * text.
+ *
+ * @param author (optional) Author of the message. This serves as a key for tagging
+ * the content of this Message when it is fed to the model as text.The author can be
+ * any alphanumeric string.
+ * @param content The text content of the structured Message.
+ * @param citationMetadata (output only) Citation information for model-generated
+ * content in this Message. If this Message was generated as output from the model,
+ * this field may be populated with attribution information for any text included in
+ * the content. This field is used only on output.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Message(
+ @JsonProperty("author") String author,
+ @JsonProperty("content") String content,
+ @JsonProperty("citationMetadata") CitationMetadata citationMetadata) {
+
+ /**
+ * Short-hand constructor for a message without citation metadata.
+ * @param author (optional) Author of the message.
+ * @param content The text content of the structured Message.
+ */
+ public Message(String author, String content) {
+ this(author, content, null);
+ }
+
+ /**
+ * A collection of source attributions for a piece of content.
+ *
+ * Citations to sources for a specific response.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record CitationMetadata(
+ @JsonProperty("citationSources") List<CitationSource> citationSources) {
+ }
+
+ /**
+ * A citation to a source for a portion of a specific response.
+ *
+ * @param startIndex (optional) Start of segment of the response that is
+ * attributed to this source. Index indicates the start of the segment, measured
+ * in bytes.
+ * @param endIndex (optional) End of the attributed segment, exclusive.
+ * @param uri (optional) URI that is attributed as a source for a portion of the
+ * text.
+ * @param license (optional) License for the GitHub project that is attributed as
+ * a source for segment.License info is required for code citations.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record CitationSource(
+ @JsonProperty("startIndex") Integer startIndex,
+ @JsonProperty("endIndex") Integer endIndex,
+ @JsonProperty("uri") String uri,
+ @JsonProperty("license") String license) {
+ }
+ }
+
+ /**
+ * All of the structured input text passed to the model as a prompt.
+ *
+ * A MessagePrompt contains a structured set of fields that provide context for the
+ * conversation, examples of user input/model output message pairs that prime the
+ * model to respond in different ways, and the conversation history or list of
+ * messages representing the alternating turns of the conversation between the user
+ * and the model.
+ *
+ * @param context (optional) Text that should be provided to the model first to ground
+ * the response. If not empty, this context will be given to the model first before
+ * the examples and messages. When using a context be sure to provide it with every
+ * request to maintain continuity. This field can be a description of your prompt to
+ * the model to help provide context and guide the responses. Examples: "Translate the
+ * phrase from English to French." or "Given a statement, classify the sentiment as
+ * happy, sad or neutral." Anything included in this field will take precedence over
+ * message history if the total input size exceeds the model's inputTokenLimit and the
+ * input request is truncated.
+ * @param examples (optional) Examples of what the model should generate. This
+ * includes both user input and the response that the model should emulate. These
+ * examples are treated identically to conversation messages except that they take
+ * precedence over the history in messages: If the total input size exceeds the
+ * model's inputTokenLimit the input will be truncated. Items will be dropped from
+ * messages before examples.
+ * @param messages (optional) A snapshot of the recent conversation history sorted
+ * chronologically. Turns alternate between two authors. If the total input size
+ * exceeds the model's inputTokenLimit the input will be truncated: The oldest items
+ * will be dropped from messages.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record MessagePrompt(
+ @JsonProperty("context") String context,
+ @JsonProperty("examples") List<Example> examples,
+ @JsonProperty("messages") List<Message> messages) {
+
+ /**
+ * Shortcut constructor for a message prompt without context.
+ * @param messages The conversation history used by the model.
+ */
+ public MessagePrompt(List<Message> messages) {
+ this(null, null, messages);
+ }
+
+ /**
+ * Shortcut constructor for a message prompt without context.
+ * @param context An input/output example used to instruct the Model. It
+ * demonstrates how the model should respond or format its response.
+ * @param messages The conversation history used by the model.
+ */
+ public MessagePrompt(String context, List<Message> messages) {
+ this(context, null, messages);
+ }
+
+ /**
+ * An input/output example used to instruct the Model. It demonstrates how the
+ * model should respond or format its response.
+ *
+ * @param input An example of an input Message from the user.
+ * @param output An example of an output Message from the model.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record Example(
+ @JsonProperty("input") Message input,
+ @JsonProperty("output") Message output) {
+ }
+ }
+
+ /**
+ * Message generation request body.
+ *
+ * @param prompt The structured textual input given to the model as a prompt. Given a
+ * prompt, the model will return what it predicts is the next message in the
+ * discussion.
+ * @param temperature (optional) Controls the randomness of the output. Values can
+ * range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that
+ * are more varied, while a value closer to 0.0 will typically result in less
+ * surprising responses from the model.
+ * @param candidateCount (optional) The number of generated response messages to
+ * return. This value must be between [1, 8], inclusive. If unset, this will default
+ * to 1.
+ * @param topP (optional) The maximum cumulative probability of tokens to consider
+ * when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling
+ * considers the smallest set of tokens whose probability sum is at least topP.
+ * @param topK (optional) The maximum number of tokens to consider when sampling. The
+ * model uses combined Top-k and nucleus sampling. Top-k sampling considers the set of
+ * topK most probable tokens.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record GenerateMessageRequest(
+ @JsonProperty("prompt") MessagePrompt prompt,
+ @JsonProperty("temperature") Float temperature,
+ @JsonProperty("candidateCount") Integer candidateCount,
+ @JsonProperty("topP") Float topP,
+ @JsonProperty("topK") Integer topK) {
+
+ /**
+ * Shortcut constructor to create a GenerateMessageRequest with only the prompt
+ * parameter.
+ * @param prompt The structured textual input given to the model as a prompt.
+ */
+ public GenerateMessageRequest(MessagePrompt prompt) {
+ this(prompt, null, null, null, null);
+ }
+
+ /**
+ * Shortcut constructor to create a GenerateMessageRequest with only the prompt
+ * and temperature parameters.
+ * @param prompt The structured textual input given to the model as a prompt.
+ * @param temperature (optional) Controls the randomness of the output.
+ * @param topK (optional) The maximum number of tokens to consider when sampling.
+ */
+ public GenerateMessageRequest(MessagePrompt prompt, Float temperature, Integer topK) {
+ this(prompt, temperature, null, null, topK);
+ }
+ }
+
+ /**
+ * The response from the model. This includes candidate messages and conversation
+ * history in the form of chronologically-ordered messages.
+ *
+ * @param candidates Candidate response messages from the model.
+ * @param messages The conversation history used by the model.
+ * @param filters A set of content filtering metadata for the prompt and response
+ * text. This indicates which SafetyCategory(s) blocked a candidate from this
+ * response, the lowest HarmProbability that triggered a block, and the HarmThreshold
+ * setting for that category.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record GenerateMessageResponse(
+ @JsonProperty("candidates") List<Message> candidates,
+ @JsonProperty("messages") List<Message> messages,
+ @JsonProperty("filters") List<ContentFilter> filters) {
+
+ /**
+ * Content filtering metadata associated with processing a single request. It
+ * contains a reason and an optional supporting string. The reason may be
+ * unspecified.
+ *
+ * @param reason The reason content was blocked during request processing.
+ * @param message A string that describes the filtering behavior in more detail.
+ */
+ @JsonInclude(Include.NON_NULL)
+ public record ContentFilter(
+ @JsonProperty("reason") BlockedReason reason,
+ @JsonProperty("message") String message) {
+
+ /**
+ * Reasons why content may have been blocked.
+ */
+ public enum BlockedReason {
+
+ /**
+ * A blocked reason was not specified.
+ */
+ BLOCKED_REASON_UNSPECIFIED,
+ /**
+ * Content was blocked by safety settings.
+ */
+ SAFETY,
+ /**
+ * Content was blocked, but the reason is uncategorized.
+ */
+ OTHER
+
+ }
+ }
+ }
+
+ /**
+ * Main method to test the VertexAiApi.
+ * @param args blank.
+ */
+ public static void main(String[] args) {
+ VertexAiApi vertexAiApi = new VertexAiApi(System.getenv("PALM_API_KEY"));
+
+ var prompt = new MessagePrompt(List.of(new Message("0", "Hello, how are you?")));
+
+ GenerateMessageRequest request = new GenerateMessageRequest(prompt);
+
+ GenerateMessageResponse response = vertexAiApi.generateMessage(request);
+
+ System.out.println(response);
+
+ System.out.println(vertexAiApi.embedText("Hello, how are you?"));
+
+ System.out.println(vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")));
+
+ System.out.println(vertexAiApi.countMessageTokens(prompt));
+
+ System.out.println(vertexAiApi.listModels());
+
+ System.out.println(vertexAiApi.listModels().stream().map(vertexAiApi::getModel).toList());
+
+ }
+
+}
+// @formatter:on
\ No newline at end of file
diff --git a/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClient.java b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClient.java
new file mode 100644
index 00000000000..bb18bdf7e19
--- /dev/null
+++ b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClient.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.vertex.embedding;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.vertex.api.VertexAiApi;
+
+/**
+ * @author Christian Tzolov
+ */
+public class VertexAiEmbeddingClient implements EmbeddingClient {
+
+ private final VertexAiApi vertexAiApi;
+
+ public VertexAiEmbeddingClient(VertexAiApi vertexAiApi) {
+ this.vertexAiApi = vertexAiApi;
+ }
+
+ @Override
+ public List<Double> embed(String text) {
+ return this.vertexAiApi.embedText(text).value();
+ }
+
+ @Override
+ public List embed(Document document) {
+ return embed(document.getContent());
+ }
+
+ @Override
+ public List<List<Double>> embed(List<String> texts) {
+ List<VertexAiApi.Embedding> vertexEmbeddings = this.vertexAiApi.batchEmbedText(texts);
+ return vertexEmbeddings.stream().map(e -> e.value()).toList();
+ }
+
+ @Override
+ public EmbeddingResponse embedForResponse(List<String> texts) {
+ List<VertexAiApi.Embedding> vertexEmbeddings = this.vertexAiApi.batchEmbedText(texts);
+ int index = 0;
+ List<Embedding> embeddings = new ArrayList<>();
+ for (VertexAiApi.Embedding vertexEmbedding : vertexEmbeddings) {
+ embeddings.add(new Embedding(vertexEmbedding.value(), index++));
+ }
+ return new EmbeddingResponse(embeddings, Map.of());
+ }
+
+}
diff --git a/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClient.java b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClient.java
new file mode 100644
index 00000000000..b51fe674da6
--- /dev/null
+++ b/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClient.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.vertex.generation;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.springframework.ai.client.AiClient;
+import org.springframework.ai.client.AiResponse;
+import org.springframework.ai.client.Generation;
+import org.springframework.ai.prompt.Prompt;
+import org.springframework.ai.prompt.messages.MessageType;
+import org.springframework.ai.vertex.api.VertexAiApi;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse;
+import org.springframework.ai.vertex.api.VertexAiApi.MessagePrompt;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+
+/**
+ * @author Christian Tzolov
+ */
+public class VertexAiChatGenerationClient implements AiClient {
+
+ private final VertexAiApi vertexAiApi;
+
+ private Float temperature;
+
+ private Float topP;
+
+ private Integer candidateCount;
+
+ private Integer maxTokens;
+
+ public VertexAiChatGenerationClient(VertexAiApi vertexAiApi) {
+ this.vertexAiApi = vertexAiApi;
+ }
+
+ @Override
+ public AiResponse generate(Prompt prompt) {
+
+ String vertexContext = prompt.getMessages()
+ .stream()
+ .filter(m -> m.getMessageType() == MessageType.SYSTEM)
+ .map(m -> m.getContent())
+ .collect(Collectors.joining("\n"));
+
+ List<VertexAiApi.Message> vertexMessages = prompt.getMessages()
+ .stream()
+ .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
+ .map(m -> new VertexAiApi.Message(m.getMessageType().getValue(), m.getContent()))
+ .toList();
+
+ Assert.isTrue(!CollectionUtils.isEmpty(vertexMessages), "No user or assistant messages found in the prompt!");
+
+ var vertexPrompt = new MessagePrompt(vertexContext, vertexMessages);
+
+ GenerateMessageRequest request = new GenerateMessageRequest(vertexPrompt, this.temperature, this.candidateCount,
+ this.topP, this.maxTokens);
+
+ GenerateMessageResponse response = this.vertexAiApi.generateMessage(request);
+
+ List<Generation> generations = response.candidates()
+ .stream()
+ .map(vmsg -> new Generation(vmsg.content()))
+ .collect(Collectors.toList());
+
+ return new AiResponse(generations);
+ }
+
+}
diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiIT.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiIT.java
new file mode 100644
index 00000000000..533d810a417
--- /dev/null
+++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiIT.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.vertex.api;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import org.springframework.ai.vertex.api.VertexAiApi.Embedding;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse;
+import org.springframework.ai.vertex.api.VertexAiApi.MessagePrompt;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for {@link VertexAiApi}. Requires a valid API key to be set via the
+ * {@code PALM_API_KEY} environment variable. At the moment Google enables it only in the US
+ * region (so use a VPN for testing).
+ *
+ * @author Christian Tzolov
+ */
+@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*")
+public class VertexAiApiIT {
+
+ VertexAiApi vertexAiApi = new VertexAiApi(System.getenv("PALM_API_KEY"));
+
+ @Test
+ public void generateMessage() {
+
+ var prompt = new MessagePrompt(List.of(new VertexAiApi.Message("0", "Hello, how are you?")));
+
+ GenerateMessageRequest request = new GenerateMessageRequest(prompt);
+
+ GenerateMessageResponse response = vertexAiApi.generateMessage(request);
+
+ assertThat(response).isNotNull();
+
+ // Vertex returns the prompt messages in the response's messages list.
+ assertThat(response.messages()).hasSize(1);
+ assertThat(response.messages().get(0)).isEqualTo(prompt.messages().get(0));
+
+ // Vertex returns the answer in the response's candidates list.
+ assertThat(response.candidates()).hasSize(1);
+ assertThat(response.candidates().get(0).author()).isNotBlank();
+ assertThat(response.candidates().get(0).content()).isNotBlank();
+ }
+
+ @Test
+ public void embedText() {
+
+ var text = "Hello, how are you?";
+
+ Embedding response = vertexAiApi.embedText(text);
+
+ assertThat(response).isNotNull();
+ assertThat(response.value()).hasSize(768);
+ }
+
+ @Test
+ public void batchEmbedText() {
+
+ var text = List.of("Hello, how are you?", "I am fine, thank you!");
+
+ List response = vertexAiApi.batchEmbedText(text);
+
+ assertThat(response).isNotNull();
+ assertThat(response).hasSize(2);
+ assertThat(response.get(0).value()).hasSize(768);
+ assertThat(response.get(1).value()).hasSize(768);
+ }
+
+ @Test
+ public void countMessageTokens() {
+
+ var text = "Hello, how are you?";
+
+ var prompt = new MessagePrompt(List.of(new VertexAiApi.Message("0", text)));
+ int response = vertexAiApi.countMessageTokens(prompt);
+
+ assertThat(response).isEqualTo(17);
+ }
+
+ @Test
+ public void listModels() {
+
+ List response = vertexAiApi.listModels();
+
+ assertThat(response).isNotNull();
+ assertThat(response).hasSizeGreaterThan(0);
+ assertThat(response).contains("models/chat-bison-001", "models/text-bison-001", "models/embedding-gecko-001");
+
+ System.out.println(" - " + response.stream()
+ .map(vertexAiApi::getModel)
+ .map(VertexAiApi.Model::toString)
+ .collect(Collectors.joining("\n - ")));
+ }
+
+ @Test
+ public void getModel() {
+
+ VertexAiApi.Model model = vertexAiApi.getModel("models/chat-bison-001");
+
+ System.out.println(model);
+ assertThat(model).isNotNull();
+ assertThat(model.displayName()).isEqualTo("Chat Bison");
+ }
+
+}
diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java
new file mode 100644
index 00000000000..78645d27a2a
--- /dev/null
+++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java
@@ -0,0 +1,150 @@
+/*
+ * Copyright 2023-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.vertex.api;
+
+import java.util.List;
+import java.util.Map;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.vertex.api.VertexAiApi.Embedding;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse;
+import org.springframework.ai.vertex.api.VertexAiApi.MessagePrompt;
+import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse.ContentFilter.BlockedReason;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.autoconfigure.web.client.RestClientTest;
+import org.springframework.context.annotation.Bean;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.test.web.client.MockRestServiceServer;
+import org.springframework.web.client.RestClient;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.test.web.client.match.MockRestRequestMatchers.content;
+import static org.springframework.test.web.client.match.MockRestRequestMatchers.method;
+import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestToUriTemplate;
+import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess;
+
+/**
+ * @author Christian Tzolov
+ */
+@RestClientTest(VertexAiApiTests.Config.class)
+public class VertexAiApiTests {
+
+ private final static String TEST_API_KEY = "test-api-key";
+
+ @Autowired
+ private VertexAiApi client;
+
+ @Autowired
+ private MockRestServiceServer server;
+
+ @Autowired
+ private ObjectMapper objectMapper;
+
+ @AfterEach
+ void resetMockServer() {
+ server.reset();
+ }
+
+ @Test
+ public void generateMessage() throws JsonProcessingException {
+
+ GenerateMessageRequest request = new GenerateMessageRequest(
+ new MessagePrompt(List.of(new VertexAiApi.Message("0", "Hello, how are you?"))));
+
+ GenerateMessageResponse expectedResponse = new GenerateMessageResponse(
+ List.of(new VertexAiApi.Message("1", "Hello, how are you?")),
+ List.of(new VertexAiApi.Message("0", "I'm fine, thank you.")),
+ List.of(new VertexAiApi.GenerateMessageResponse.ContentFilter(BlockedReason.SAFETY, "reason")));
+
+ server
+ .expect(requestToUriTemplate("/models/{model}:generateMessage?key={apiKey}",
+ VertexAiApi.DEFAULT_GENERATE_MODEL, TEST_API_KEY))
+ .andExpect(method(HttpMethod.POST))
+ .andExpect(content().json(objectMapper.writeValueAsString(request)))
+ .andRespond(withSuccess(objectMapper.writeValueAsString(expectedResponse), MediaType.APPLICATION_JSON));
+
+ GenerateMessageResponse response = client.generateMessage(request);
+
+ assertThat(response).isEqualTo(expectedResponse);
+
+ server.verify();
+ }
+
+ @Test
+ public void embedText() throws JsonProcessingException {
+
+ String text = "Hello, how are you?";
+
+ Embedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));
+
+ server
+ .expect(requestToUriTemplate("/models/{model}:embedText?key={apiKey}", VertexAiApi.DEFAULT_EMBEDDING_MODEL,
+ TEST_API_KEY))
+ .andExpect(method(HttpMethod.POST))
+ .andExpect(content().json(objectMapper.writeValueAsString(Map.of("text", text))))
+ .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)),
+ MediaType.APPLICATION_JSON));
+
+ Embedding embedding = client.embedText(text);
+
+ assertThat(embedding).isEqualTo(expectedEmbedding);
+
+ server.verify();
+ }
+
+ @Test
+ public void batchEmbedText() throws JsonProcessingException {
+
+ List texts = List.of("Hello, how are you?", "I'm fine, thank you.");
+
+ List expectedEmbeddings = List.of(new Embedding(List.of(0.1, 0.2, 0.3)),
+ new Embedding(List.of(0.4, 0.5, 0.6)));
+
+ server
+ .expect(requestToUriTemplate("/models/{model}:batchEmbedText?key={apiKey}",
+ VertexAiApi.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))
+ .andExpect(method(HttpMethod.POST))
+ .andExpect(content().json(objectMapper.writeValueAsString(Map.of("texts", texts))))
+ .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embeddings", expectedEmbeddings)),
+ MediaType.APPLICATION_JSON));
+
+ List embeddings = client.batchEmbedText(texts);
+
+ assertThat(embeddings).isEqualTo(expectedEmbeddings);
+
+ server.verify();
+ }
+
+ @SpringBootConfiguration
+ static class Config {
+
+ @Bean
+ public VertexAiApi audioApi(RestClient.Builder builder) {
+ return new VertexAiApi("", TEST_API_KEY, VertexAiApi.DEFAULT_GENERATE_MODEL,
+ VertexAiApi.DEFAULT_EMBEDDING_MODEL, builder);
+ }
+
+ }
+
+}
diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java
new file mode 100644
index 00000000000..78ebbfc308a
--- /dev/null
+++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java
@@ -0,0 +1,48 @@
+package org.springframework.ai.vertex.embedding;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.vertex.api.VertexAiApi;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
+
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@SpringBootTest
+@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*")
+class VertexAiEmbeddingClientIT {
+
+ @Autowired
+ private VertexAiEmbeddingClient embeddingClient;
+
+ @Test
+ void simpleEmbedding() {
+ assertThat(embeddingClient).isNotNull();
+ EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
+ assertThat(embeddingResponse.getData()).hasSize(1);
+ assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
+ assertThat(embeddingClient.dimensions()).isEqualTo(768);
+ }
+
+ @SpringBootConfiguration
+ public static class TestConfiguration {
+
+ @Bean
+ public VertexAiApi vertexAiApi() {
+ return new VertexAiApi(System.getenv("PALM_API_KEY"));
+ }
+
+ @Bean
+ public VertexAiEmbeddingClient vertexAiEmbedding(VertexAiApi vertexAiApi) {
+ return new VertexAiEmbeddingClient(vertexAiApi);
+ }
+
+ }
+
+}
diff --git a/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java
new file mode 100644
index 00000000000..9c5a610067d
--- /dev/null
+++ b/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java
@@ -0,0 +1,130 @@
+package org.springframework.ai.vertex.generation;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import org.springframework.ai.client.AiResponse;
+import org.springframework.ai.client.Generation;
+import org.springframework.ai.parser.BeanOutputParser;
+import org.springframework.ai.parser.ListOutputParser;
+import org.springframework.ai.parser.MapOutputParser;
+import org.springframework.ai.prompt.Prompt;
+import org.springframework.ai.prompt.PromptTemplate;
+import org.springframework.ai.prompt.SystemPromptTemplate;
+import org.springframework.ai.prompt.messages.Message;
+import org.springframework.ai.prompt.messages.UserMessage;
+import org.springframework.ai.vertex.api.VertexAiApi;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.convert.support.DefaultConversionService;
+import org.springframework.core.io.Resource;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@SpringBootTest
+@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*")
+class VertexAiChatGenerationClientIT {
+
+ @Autowired
+ private VertexAiChatGenerationClient client;
+
+ @Value("classpath:/prompts/system-message.st")
+ private Resource systemResource;
+
+ @Test
+ void roleTest() {
+ String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
+ String name = "Bob";
+ String voice = "pirate";
+ UserMessage userMessage = new UserMessage(request);
+ SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
+ Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
+ Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+ AiResponse response = client.generate(prompt);
+ assertThat(response.getGeneration().getContent()).contains("Bartholomew");
+ }
+
+ // @Test
+ void outputParser() {
+ DefaultConversionService conversionService = new DefaultConversionService();
+ ListOutputParser outputParser = new ListOutputParser(conversionService);
+
+ String format = outputParser.getFormat();
+ String template = """
+ List five {subject}
+ {format}
+ """;
+ PromptTemplate promptTemplate = new PromptTemplate(template,
+ Map.of("subject", "ice cream flavors.", "format", format));
+ Prompt prompt = new Prompt(promptTemplate.createMessage());
+ Generation generation = this.client.generate(prompt).getGeneration();
+
+ List list = outputParser.parse(generation.getContent());
+ assertThat(list).hasSize(5);
+
+ }
+
+ // @Test
+ void mapOutputParser() {
+ MapOutputParser outputParser = new MapOutputParser();
+
+ String format = outputParser.getFormat();
+ String template = """
+ Provide me a List of {subject}
+ {format}
+ """;
+ PromptTemplate promptTemplate = new PromptTemplate(template,
+ Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
+ Prompt prompt = new Prompt(promptTemplate.createMessage());
+ Generation generation = client.generate(prompt).getGeneration();
+
+ Map result = outputParser.parse(generation.getContent());
+ assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
+
+ }
+
+ record ActorsFilmsRecord(String actor, List movies) {
+ }
+
+ // @Test
+ void beanOutputParserRecords() {
+
+ BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class);
+
+ String format = outputParser.getFormat();
+ String template = """
+ Generate the filmography of 5 movies for Tom Hanks.
+ {format}
+ """;
+ PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
+ Prompt prompt = new Prompt(promptTemplate.createMessage());
+ Generation generation = client.generate(prompt).getGeneration();
+
+ ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent());
+ assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
+ assertThat(actorsFilms.movies()).hasSize(5);
+ }
+
+ @SpringBootConfiguration
+ public static class TestConfiguration {
+
+ @Bean
+ public VertexAiApi vertexAiApi() {
+ return new VertexAiApi(System.getenv("PALM_API_KEY"));
+ }
+
+ @Bean
+ public VertexAiChatGenerationClient vertexAiEmbedding(VertexAiApi vertexAiApi) {
+ return new VertexAiChatGenerationClient(vertexAiApi);
+ }
+
+ }
+
+}
diff --git a/spring-ai-vertex-ai/src/test/resources/prompts/system-message.st b/spring-ai-vertex-ai/src/test/resources/prompts/system-message.st
new file mode 100644
index 00000000000..dc2cf2dcd84
--- /dev/null
+++ b/spring-ai-vertex-ai/src/test/resources/prompts/system-message.st
@@ -0,0 +1,4 @@
+You are a helpful AI assistant. Your name is {name}.
+You are an AI assistant that helps people find information.
+Your name is {name}
+You should reply to the user's request with your name and also in the style of a {voice}.
\ No newline at end of file