From d1710711726ff47cbde0d6b20766593842405805 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Sun, 29 Sep 2024 17:34:59 -0400 Subject: [PATCH 1/2] Add retry support to VertexAI embedding and chat models Introduces retry functionality to VertexAI embedding and chat models, enhancing their resilience against transient failures. It also corrects a typo in the VertexAiEmbeddingConnectionDetails class name. Key changes: * Add RetryTemplate to VertexAiTextEmbeddingModel and VertexAiGeminiChatModel * Introduce spring-ai-retry dependency * Refactor code to support retry logic * Update auto-configuration classes to incorporate retry functionality * Fix typo in VertexAiEmbeddingConnectionDetails class name --- models/spring-ai-vertex-ai-embedding/pom.xml | 6 + ...> VertexAiEmbeddingConnectionDetails.java} | 13 +- .../VertexAiMultimodalEmbeddingModel.java | 6 +- .../text/VertexAiTextEmbeddingModel.java | 113 +++++++++----- .../VertexAiMultimodalEmbeddingModelIT.java | 8 +- .../text/TestVertexAiTextEmbeddingModel.java | 57 +++++++ .../text/VertexAiTextEmbeddingModelIT.java | 8 +- .../text/VertexAiTextEmbeddingRetryTests.java | 143 ++++++++++++++++++ models/spring-ai-vertex-ai-gemini/pom.xml | 6 + .../gemini/VertexAiGeminiChatModel.java | 105 +++++++------ .../gemini/TestVertexAiGeminiChatModel.java | 45 ++++++ .../gemini/VertexAiGeminiRetryTests.java | 120 +++++++++++++++ .../VertexAiEmbeddingAutoConfiguration.java | 22 ++- .../VertexAiGeminiAutoConfiguration.java | 12 +- 14 files changed, 549 insertions(+), 115 deletions(-) rename models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/{VertexAiEmbeddigConnectionDetails.java => VertexAiEmbeddingConnectionDetails.java} (89%) create mode 100644 models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java create mode 100644 models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java create mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java create mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml index 8fdbc8afc08..08526534078 100644 --- a/models/spring-ai-vertex-ai-embedding/pom.xml +++ b/models/spring-ai-vertex-ai-embedding/pom.xml @@ -52,6 +52,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.springframework spring-web diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java similarity index 89% rename from models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java rename to models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java index a4653f248e9..e119cd31410 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java @@ -23,11 +23,14 @@ import com.google.cloud.aiplatform.v1.PredictionServiceSettings; /** - * VertexAiEmbeddigConnectionDetails represents the details of a connection to the Vertex + * VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex * AI embedding service. It provides methods to access the project ID, location, * publisher, and PredictionServiceSettings. + * + * @author Christian Tzolov + * @since 1.0.0 */ -public class VertexAiEmbeddigConnectionDetails { +public class VertexAiEmbeddingConnectionDetails { private static final String DEFAULT_LOCATION = "us-central1"; @@ -55,7 +58,7 @@ public class VertexAiEmbeddigConnectionDetails { private final String publisher; - public VertexAiEmbeddigConnectionDetails(String endpoint, String projectId, String location, String publisher) { + public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) { this.projectId = projectId; this.location = location; this.publisher = publisher; @@ -119,7 +122,7 @@ public Builder withPublisher(String publisher) { return this; } - public VertexAiEmbeddigConnectionDetails build() { + public VertexAiEmbeddingConnectionDetails build() { if (!StringUtils.hasText(this.endpoint)) { if (!StringUtils.hasText(this.location)) { this.endpoint = DEFAULT_ENDPOINT; @@ -134,7 +137,7 @@ public VertexAiEmbeddigConnectionDetails build() { this.publisher = DEFAULT_PUBLISHER; } - return new VertexAiEmbeddigConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher); + return new VertexAiEmbeddingConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher); } } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java index 4d3a8066eb7..4d217509d13 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -35,7 +35,7 @@ import org.springframework.ai.embedding.EmbeddingResultMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder; @@ -76,9 +76,9 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel private static final List SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp")); - private final VertexAiEmbeddigConnectionDetails connectionDetails; + private final VertexAiEmbeddingConnectionDetails connectionDetails; - public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails, + public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) { Assert.notNull(defaultEmbeddingOptions, "VertexAiMultimodalEmbeddingOptions must not be null"); diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 7f2654db553..fdbf677188c 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -29,14 +29,17 @@ import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -53,16 +56,22 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { public final VertexAiTextEmbeddingOptions defaultOptions; - private final VertexAiEmbeddigConnectionDetails connectionDetails; + private final VertexAiEmbeddingConnectionDetails connectionDetails; + + private final RetryTemplate retryTemplate; - public VertexAiTextEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails, + public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiTextEmbeddingOptions defaultEmbeddingOptions) { + this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, + VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) { Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null"); - + Assert.notNull(retryTemplate, "retryTemplate must not be null"); this.defaultOptions = defaultEmbeddingOptions.initializeDefaults(); - this.connectionDetails = connectionDetails; + this.retryTemplate = retryTemplate; } @Override @@ -73,46 +82,23 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { + return retryTemplate.execute(context -> { + VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions; - VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions; - - if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) { - var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(); - finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, - VertexAiTextEmbeddingOptions.class); - } - - try (PredictionServiceClient client = PredictionServiceClient - .create(this.connectionDetails.getPredictionServiceSettings())) { - - EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); - - PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder() - .setEndpoint(endpointName.toString()); - - TextParametersBuilder parametersBuilder = TextParametersBuilder.of(); - - if (finalOptions.getAutoTruncate() != null) { - parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate()); - } - - if (finalOptions.getDimensions() != null) { - parametersBuilder.withOutputDimensionality(finalOptions.getDimensions()); + if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) { + var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(); + finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, + VertexAiTextEmbeddingOptions.class); } - predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build())); + PredictionServiceClient client = createPredictionServiceClient(); - for (int i = 0; i < request.getInstructions().size(); i++) { + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); - TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i)) - .withTaskType(finalOptions.getTaskType().name()); - if (StringUtils.hasText(finalOptions.getTitle())) { - instanceBuilder.withTitle(finalOptions.getTitle()); - } - predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build())); - } + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, + finalOptions); - PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build()); + PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder); int index = 0; int totalTokenCount = 0; @@ -131,12 +117,57 @@ public EmbeddingResponse call(EmbeddingRequest request) { } return new EmbeddingResponse(embeddingList, generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); + // } + // catch (Exception e) { + // throw new RuntimeException(e); + // } + }); + } + + protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, + VertexAiTextEmbeddingOptions finalOptions) { + PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString()); + + TextParametersBuilder parametersBuilder = TextParametersBuilder.of(); + + if (finalOptions.getAutoTruncate() != null) { + parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate()); } - catch (Exception e) { + + if (finalOptions.getDimensions() != null) { + parametersBuilder.withOutputDimensionality(finalOptions.getDimensions()); + } + + predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build())); + + for (int i = 0; i < request.getInstructions().size(); i++) { + + TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i)) + .withTaskType(finalOptions.getTaskType().name()); + if (StringUtils.hasText(finalOptions.getTitle())) { + instanceBuilder.withTitle(finalOptions.getTitle()); + } + predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build())); + } + return predictRequestBuilder; + } + + // for testing + PredictionServiceClient createPredictionServiceClient() { + try { + return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings()); + } + catch (IOException e) { throw new RuntimeException(e); } } + // for testing + PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build()); + return embeddingResponse; + } + private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) { EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); metadata.setModel(model); diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java index 7f3f1a61639..6caf874323e 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java @@ -24,7 +24,7 @@ import org.springframework.ai.embedding.DocumentEmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResultMetadata; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -213,8 +213,8 @@ void textImageAndVideoEmbedding() { static class Config { @Bean - public VertexAiEmbeddigConnectionDetails connectionDetails() { - return VertexAiEmbeddigConnectionDetails.builder() + public VertexAiEmbeddingConnectionDetails connectionDetails() { + return VertexAiEmbeddingConnectionDetails.builder() .withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")) .withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION")) .build(); @@ -222,7 +222,7 @@ public VertexAiEmbeddigConnectionDetails connectionDetails() { @Bean public VertexAiMultimodalEmbeddingModel vertexAiEmbeddingModel( - VertexAiEmbeddigConnectionDetails connectionDetails) { + VertexAiEmbeddingConnectionDetails connectionDetails) { VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder() .withModel(VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001) diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java new file mode 100644 index 00000000000..3ba3d878a0a --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java @@ -0,0 +1,57 @@ +package org.springframework.ai.vertexai.embedding.text; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; +import org.springframework.retry.support.RetryTemplate; + +import java.io.IOException; + +public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel { + + private PredictionServiceClient mockPredictionServiceClient; + + private PredictRequest.Builder mockPredictRequestBuilder; + + public TestVertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, + VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) { + super(connectionDetails, defaultEmbeddingOptions, retryTemplate); + } + + public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) { + this.mockPredictionServiceClient = mockPredictionServiceClient; + } + + @Override + PredictionServiceClient createPredictionServiceClient() { + if (mockPredictionServiceClient != null) { + return mockPredictionServiceClient; + } + return super.createPredictionServiceClient(); + } + + @Override + PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + if (mockPredictionServiceClient != null) { + return mockPredictionServiceClient.predict(predictRequestBuilder.build()); + } + return super.getPredictResponse(client, predictRequestBuilder); + } + + public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) { + this.mockPredictRequestBuilder = mockPredictRequestBuilder; + } + + @Override + protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, + VertexAiTextEmbeddingOptions finalOptions) { + if (mockPredictRequestBuilder != null) { + return mockPredictRequestBuilder; + } + return super.getPredictRequestBuilder(request, endpointName, finalOptions); + } + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java index 4c9a9cdcb0f..f98c96b1baa 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java @@ -24,7 +24,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -67,15 +67,15 @@ void defaultEmbedding(String modelName) { static class Config { @Bean - public VertexAiEmbeddigConnectionDetails connectionDetails() { - return VertexAiEmbeddigConnectionDetails.builder() + public VertexAiEmbeddingConnectionDetails connectionDetails() { + return VertexAiEmbeddingConnectionDetails.builder() .withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")) .withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION")) .build(); } @Bean - public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails) { + public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails) { VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java new file mode 100644 index 00000000000..2f8658c1d14 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java @@ -0,0 +1,143 @@ +package org.springframework.ai.vertexai.embedding.text; + +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class VertexAiTextEmbeddingRetryTests { + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + @Mock + private PredictionServiceClient mockPredictionServiceClient; + + @Mock + private VertexAiEmbeddingConnectionDetails mockConnectionDetails; + + @Mock + private PredictRequest.Builder mockPredictRequestBuilder; + + @Mock + private PredictionServiceSettings mockPredictionServiceSettings; + + private TestVertexAiTextEmbeddingModel embeddingModel; + + @BeforeEach + public void setUp() { + retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + embeddingModel = new TestVertexAiTextEmbeddingModel(mockConnectionDetails, + VertexAiTextEmbeddingOptions.builder().build(), retryTemplate); + embeddingModel.setMockPredictionServiceClient(mockPredictionServiceClient); + embeddingModel.setMockPredictRequestBuilder(mockPredictRequestBuilder); + when(mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance()); + } + + @Test + public void vertexAiEmbeddingTransientError() { + // Setup the mock PredictResponse + PredictResponse mockResponse = PredictResponse.newBuilder() + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("embeddings", Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("values", + Value.newBuilder() + .setListValue(com.google.protobuf.ListValue.newBuilder() + .addValues(Value.newBuilder().setNumberValue(9.9)) + .addValues(Value.newBuilder().setNumberValue(8.8)) + .build()) + .build()) + .putFields("statistics", + Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("token_count", Value.newBuilder().setNumberValue(10).build()) + .build()) + .build()) + .build()) + .build()) + .build()) + .build()) + .build(); + + // Setup the mock PredictionServiceClient + when(mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(mockResponse); + + EmbeddingResponse result = embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); + + assertThat(result).isNotNull(); + assertThat(result.getResults()).hasSize(1); + assertThat(result.getResults().get(0).getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + + verify(mockPredictRequestBuilder, times(3)).build(); + } + + @Test + public void vertexAiEmbeddingNonTransientError() { + // Setup the mock PredictionServiceClient to throw a non-transient error + when(mockPredictionServiceClient.predict(any())) + .thenThrow(new RuntimeException("Non Transient Error")); + + // Assert that a RuntimeException is thrown and not retried + assertThrows(RuntimeException.class, () -> embeddingModel + .call(new EmbeddingRequest(List.of("text1", "text2"), null))); + + // Verify that predict was called only once (no retries for non-transient errors) + verify(mockPredictionServiceClient, times(1)).predict(any()); + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/pom.xml b/models/spring-ai-vertex-ai-gemini/pom.xml index 3ebd175128f..c903c8dc0be 100644 --- a/models/spring-ai-vertex-ai-gemini/pom.xml +++ b/models/spring-ai-vertex-ai-gemini/pom.xml @@ -52,6 +52,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.springframework spring-web diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index b7ea5051d3d..ece66fc31c3 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -15,14 +15,16 @@ */ package org.springframework.ai.vertexai.gemini; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import com.google.cloud.vertexai.api.GoogleSearchRetrieval; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.google.cloud.vertexai.VertexAI; +import com.google.cloud.vertexai.api.*; +import com.google.cloud.vertexai.api.Candidate.FinishReason; +import com.google.cloud.vertexai.generativeai.GenerativeModel; +import com.google.cloud.vertexai.generativeai.PartMaker; +import com.google.cloud.vertexai.generativeai.ResponseStream; +import com.google.protobuf.Struct; +import com.google.protobuf.util.JsonFormat; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -42,35 +44,23 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage; import org.springframework.beans.factory.DisposableBean; import org.springframework.lang.NonNull; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.Candidate; -import com.google.cloud.vertexai.api.Candidate.FinishReason; -import com.google.cloud.vertexai.api.Content; -import com.google.cloud.vertexai.api.FunctionCall; -import com.google.cloud.vertexai.api.FunctionDeclaration; -import com.google.cloud.vertexai.api.FunctionResponse; -import com.google.cloud.vertexai.api.GenerateContentResponse; -import com.google.cloud.vertexai.api.GenerationConfig; -import com.google.cloud.vertexai.api.Part; -import com.google.cloud.vertexai.api.Schema; -import com.google.cloud.vertexai.api.Tool; -import com.google.cloud.vertexai.generativeai.GenerativeModel; -import com.google.cloud.vertexai.generativeai.PartMaker; -import com.google.cloud.vertexai.generativeai.ResponseStream; -import com.google.protobuf.Struct; -import com.google.protobuf.util.JsonFormat; - import reactor.core.publisher.Flux; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + /** * @author Christian Tzolov * @author Grogdunn @@ -86,6 +76,11 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements private final VertexAiGeminiChatOptions defaultOptions; + /** + * The retry template used to retry the API calls. + */ + private final RetryTemplate retryTemplate; + private final GenerationConfig generationConfig; public enum GeminiMessageType { @@ -152,43 +147,52 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions opti public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) { + this(vertexAI, options, functionCallbackContext, toolFunctionCallbacks, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, + FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks, + RetryTemplate retryTemplate) { super(functionCallbackContext, options, toolFunctionCallbacks); Assert.notNull(vertexAI, "VertexAI must not be null"); Assert.notNull(options, "VertexAiGeminiChatOptions must not be null"); Assert.notNull(options.getModel(), "VertexAiGeminiChatOptions.modelName must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); this.vertexAI = vertexAI; this.defaultOptions = options; this.generationConfig = toGenerationConfig(options); + this.retryTemplate = retryTemplate; } // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini @Override public ChatResponse call(Prompt prompt) { + return retryTemplate.execute(context -> { + var geminiRequest = createGeminiRequest(prompt); - var geminiRequest = createGeminiRequest(prompt); + GenerateContentResponse response = this.getContentResponse(geminiRequest); - GenerateContentResponse response = this.getContentResponse(geminiRequest); - - List generations = response.getCandidatesList() - .stream() - .map(this::responseCandiateToGeneration) - .flatMap(List::stream) - .toList(); + List generations = response.getCandidatesList() + .stream() + .map(this::responseCandiateToGeneration) + .flatMap(List::stream) + .toList(); - ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); + ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); - if (!isProxyToolCalls(prompt, this.defaultOptions) - && isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, chatResponse); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.call(new Prompt(toolCallConversation, prompt.getOptions())); - } + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, chatResponse); + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + } - return chatResponse; + return chatResponse; + }); } @Override @@ -525,7 +529,14 @@ private static Schema jsonToSchema(String json) { } } - private GenerateContentResponse getContentResponse(GeminiRequest request) { + /** + * Generates the content response based on the provided Gemini request. Package + * protected for testing purposes. + * @param request the GeminiRequest containing the content and model information + * @return a GenerateContentResponse containing the generated content + * @throws RuntimeException if content generation fails + */ + GenerateContentResponse getContentResponse(GeminiRequest request) { try { return request.model.generateContent(request.contents); } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java new file mode 100644 index 00000000000..5c3fcffdbbb --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java @@ -0,0 +1,45 @@ +package org.springframework.ai.vertexai.gemini; + +import com.google.cloud.vertexai.VertexAI; +import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.generativeai.GenerativeModel; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.retry.support.RetryTemplate; + +import java.io.IOException; +import java.util.List; + +public class TestVertexAiGeminiChatModel extends VertexAiGeminiChatModel { + + private GenerativeModel mockGenerativeModel; + + public TestVertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, + FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks, + RetryTemplate retryTemplate) { + super(vertexAI, options, functionCallbackContext, toolFunctionCallbacks, retryTemplate); + } + + @Override + GenerateContentResponse getContentResponse(GeminiRequest request) { + if (mockGenerativeModel != null) { + try { + return mockGenerativeModel.generateContent(request.contents()); + } + catch (IOException e) { + // Should not be thrown by testing class + throw new RuntimeException("Failed to generate content", e); + } + catch (RuntimeException e) { + // Re-throw RuntimeExceptions (including TransientAiException) as is + throw e; + } + } + return super.getContentResponse(request); + } + + public void setMockGenerativeModel(GenerativeModel mockGenerativeModel) { + this.mockGenerativeModel = mockGenerativeModel; + } + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java new file mode 100644 index 00000000000..924342384a3 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -0,0 +1,120 @@ +package org.springframework.ai.vertexai.gemini; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.*; + +import com.google.cloud.vertexai.VertexAI; +import com.google.cloud.vertexai.api.Candidate; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.generativeai.GenerativeModel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@SuppressWarnings("unchecked") +@ExtendWith(MockitoExtension.class) +public class VertexAiGeminiRetryTests { + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + @Mock + private VertexAI vertexAI; + + @Mock + private GenerativeModel mockGenerativeModel; + + private TestVertexAiGeminiChatModel chatModel; + + @BeforeEach + public void setUp() { + retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + chatModel = new TestVertexAiGeminiChatModel(vertexAI, + VertexAiGeminiChatOptions.builder() + .withTemperature(0.7) + .withTopP(1.0) + .withModel(VertexAiGeminiChatModel.ChatModel.GEMINI_PRO.getValue()) + .build(), + null, Collections.emptyList(), retryTemplate); + + chatModel.setMockGenerativeModel(mockGenerativeModel); + } + + @Test + public void vertexAiGeminiChatTransientError() throws IOException { + // Create a mocked successful response + GenerateContentResponse mockedResponse = GenerateContentResponse.newBuilder() + .addCandidates(Candidate.newBuilder() + .setContent(Content.newBuilder().addParts(Part.newBuilder().setText("Response").build()).build()) + .build()) + .build(); + + when(mockGenerativeModel.generateContent(any(List.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(mockedResponse); + + // Call the chat model + ChatResponse result = chatModel.call(new Prompt("test prompt")); + + // Assertions + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isEqualTo("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void vertexAiGeminiChatNonTransientError() throws Exception { + // Set up the mock GenerativeModel to throw a non-transient RuntimeException + when(mockGenerativeModel.generateContent(any(List.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + + // Assert that a RuntimeException is thrown when calling the chat model + assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("test prompt"))); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java index 16bde0c6ce6..668975879f3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java @@ -17,14 +17,18 @@ import java.io.IOException; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel; import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -36,20 +40,22 @@ * @author Christian Tzolov * @since 1.0.0 */ +@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class }) @ConditionalOnClass({ VertexAI.class, VertexAiTextEmbeddingModel.class }) @EnableConfigurationProperties({ VertexAiEmbeddingConnectionProperties.class, VertexAiTextEmbeddingProperties.class, - VertexAiMultimodalEmbeddingProperties.class, }) + VertexAiMultimodalEmbeddingProperties.class }) +@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class }) public class VertexAiEmbeddingAutoConfiguration { @Bean @ConditionalOnMissingBean - public VertexAiEmbeddigConnectionDetails connectionDetails( + public VertexAiEmbeddingConnectionDetails connectionDetails( VertexAiEmbeddingConnectionProperties connectionProperties) { Assert.hasText(connectionProperties.getProjectId(), "Vertex AI project-id must be set!"); Assert.hasText(connectionProperties.getLocation(), "Vertex AI location must be set!"); - var connectionBuilder = VertexAiEmbeddigConnectionDetails.builder() + var connectionBuilder = VertexAiEmbeddingConnectionDetails.builder() .withProjectId(connectionProperties.getProjectId()) .withLocation(connectionProperties.getLocation()); @@ -65,17 +71,17 @@ public VertexAiEmbeddigConnectionDetails connectionDetails( @ConditionalOnMissingBean @ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddigConnectionDetails connectionDetails, - VertexAiTextEmbeddingProperties textEmbeddingProperties) throws IOException { + public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails, + VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate) { - return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions()); + return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), retryTemplate); } @Bean @ConditionalOnMissingBean @ConditionalOnProperty(prefix = VertexAiMultimodalEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public VertexAiMultimodalEmbeddingModel multimodalEmbedding(VertexAiEmbeddigConnectionDetails connectionDetails, + public VertexAiMultimodalEmbeddingModel multimodalEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiMultimodalEmbeddingProperties multimodalEmbeddingProperties) throws IOException { return new VertexAiMultimodalEmbeddingModel(connectionDetails, multimodalEmbeddingProperties.getOptions()); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java index 6003594f7bf..17dc2121a87 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java @@ -18,16 +18,20 @@ import java.io.IOException; import java.util.List; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -40,10 +44,12 @@ * * @author Christian Tzolov * @author Soby Chacko - * @since 0.8.0 + * @since 1.0.0 */ +@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class }) @ConditionalOnClass({ VertexAI.class, VertexAiGeminiChatModel.class }) @EnableConfigurationProperties({ VertexAiGeminiChatProperties.class, VertexAiGeminiConnectionProperties.class }) +@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class }) public class VertexAiGeminiAutoConfiguration { @Bean @@ -79,12 +85,12 @@ public VertexAI vertexAi(VertexAiGeminiConnectionProperties connectionProperties @ConditionalOnProperty(prefix = VertexAiGeminiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGeminiChatProperties chatProperties, - List toolFunctionCallbacks, ApplicationContext context) { + List toolFunctionCallbacks, ApplicationContext context, RetryTemplate retryTemplate) { FunctionCallbackContext functionCallbackContext = springAiFunctionManager(context); return new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), functionCallbackContext, - toolFunctionCallbacks); + toolFunctionCallbacks, retryTemplate); } /** From 3739814bc16f52b442c478c5d462dce54f2d2cef Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Sun, 29 Sep 2024 17:46:37 -0400 Subject: [PATCH 2/2] remove extraneous commented out code --- .../vertexai/embedding/text/VertexAiTextEmbeddingModel.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index fdbf677188c..0439eb87995 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -117,10 +117,6 @@ public EmbeddingResponse call(EmbeddingRequest request) { } return new EmbeddingResponse(embeddingList, generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); - // } - // catch (Exception e) { - // throw new RuntimeException(e); - // } }); }