diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml
index 8fdbc8afc08..08526534078 100644
--- a/models/spring-ai-vertex-ai-embedding/pom.xml
+++ b/models/spring-ai-vertex-ai-embedding/pom.xml
@@ -52,6 +52,12 @@
${project.parent.version}
+
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
+
+
org.springframework
spring-web
diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java
similarity index 89%
rename from models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java
rename to models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java
index a4653f248e9..e119cd31410 100644
--- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java
+++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java
@@ -23,11 +23,14 @@
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
/**
- * VertexAiEmbeddigConnectionDetails represents the details of a connection to the Vertex
+ * VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex
* AI embedding service. It provides methods to access the project ID, location,
* publisher, and PredictionServiceSettings.
+ *
+ * @author Christian Tzolov
+ * @since 1.0.0
*/
-public class VertexAiEmbeddigConnectionDetails {
+public class VertexAiEmbeddingConnectionDetails {
private static final String DEFAULT_LOCATION = "us-central1";
@@ -55,7 +58,7 @@ public class VertexAiEmbeddigConnectionDetails {
private final String publisher;
- public VertexAiEmbeddigConnectionDetails(String endpoint, String projectId, String location, String publisher) {
+ public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) {
this.projectId = projectId;
this.location = location;
this.publisher = publisher;
@@ -119,7 +122,7 @@ public Builder withPublisher(String publisher) {
return this;
}
- public VertexAiEmbeddigConnectionDetails build() {
+ public VertexAiEmbeddingConnectionDetails build() {
if (!StringUtils.hasText(this.endpoint)) {
if (!StringUtils.hasText(this.location)) {
this.endpoint = DEFAULT_ENDPOINT;
@@ -134,7 +137,7 @@ public VertexAiEmbeddigConnectionDetails build() {
this.publisher = DEFAULT_PUBLISHER;
}
- return new VertexAiEmbeddigConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
+ return new VertexAiEmbeddingConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
}
}
diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java
index 4d3a8066eb7..4d217509d13 100644
--- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java
+++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java
@@ -35,7 +35,7 @@
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder;
@@ -76,9 +76,9 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
private static final List SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));
- private final VertexAiEmbeddigConnectionDetails connectionDetails;
+ private final VertexAiEmbeddingConnectionDetails connectionDetails;
- public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
+ public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) {
Assert.notNull(defaultEmbeddingOptions, "VertexAiMultimodalEmbeddingOptions must not be null");
diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java
index 7f2654db553..0439eb87995 100644
--- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java
+++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java
@@ -29,14 +29,17 @@
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder;
-import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -53,16 +56,22 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
public final VertexAiTextEmbeddingOptions defaultOptions;
- private final VertexAiEmbeddigConnectionDetails connectionDetails;
+ private final VertexAiEmbeddingConnectionDetails connectionDetails;
+
+ private final RetryTemplate retryTemplate;
- public VertexAiTextEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
+ public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
+ this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
+ }
+ public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
+ VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
-
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
-
this.connectionDetails = connectionDetails;
+ this.retryTemplate = retryTemplate;
}
@Override
@@ -73,46 +82,23 @@ public float[] embed(Document document) {
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
+ return retryTemplate.execute(context -> {
+ VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;
- VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;
-
- if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
- var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
- finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
- VertexAiTextEmbeddingOptions.class);
- }
-
- try (PredictionServiceClient client = PredictionServiceClient
- .create(this.connectionDetails.getPredictionServiceSettings())) {
-
- EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
-
- PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder()
- .setEndpoint(endpointName.toString());
-
- TextParametersBuilder parametersBuilder = TextParametersBuilder.of();
-
- if (finalOptions.getAutoTruncate() != null) {
- parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
- }
-
- if (finalOptions.getDimensions() != null) {
- parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
+ if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
+ var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
+ finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
+ VertexAiTextEmbeddingOptions.class);
}
- predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));
+ PredictionServiceClient client = createPredictionServiceClient();
- for (int i = 0; i < request.getInstructions().size(); i++) {
+ EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
- TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
- .withTaskType(finalOptions.getTaskType().name());
- if (StringUtils.hasText(finalOptions.getTitle())) {
- instanceBuilder.withTitle(finalOptions.getTitle());
- }
- predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
- }
+ PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
+ finalOptions);
- PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
+ PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);
int index = 0;
int totalTokenCount = 0;
@@ -131,12 +117,53 @@ public EmbeddingResponse call(EmbeddingRequest request) {
}
return new EmbeddingResponse(embeddingList,
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
+ });
+ }
+
+ protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
+ VertexAiTextEmbeddingOptions finalOptions) {
+ PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString());
+
+ TextParametersBuilder parametersBuilder = TextParametersBuilder.of();
+
+ if (finalOptions.getAutoTruncate() != null) {
+ parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
}
- catch (Exception e) {
+
+ if (finalOptions.getDimensions() != null) {
+ parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
+ }
+
+ predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));
+
+ for (int i = 0; i < request.getInstructions().size(); i++) {
+
+ TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
+ .withTaskType(finalOptions.getTaskType().name());
+ if (StringUtils.hasText(finalOptions.getTitle())) {
+ instanceBuilder.withTitle(finalOptions.getTitle());
+ }
+ predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
+ }
+ return predictRequestBuilder;
+ }
+
+ // for testing
+ PredictionServiceClient createPredictionServiceClient() {
+ try {
+ return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings());
+ }
+ catch (IOException e) {
throw new RuntimeException(e);
}
}
+ // for testing
+ PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
+ PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
+ return embeddingResponse;
+ }
+
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
metadata.setModel(model);
diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java
index 7f3f1a61639..6caf874323e 100644
--- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java
+++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java
@@ -24,7 +24,7 @@
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResultMetadata;
-import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -213,8 +213,8 @@ void textImageAndVideoEmbedding() {
static class Config {
@Bean
- public VertexAiEmbeddigConnectionDetails connectionDetails() {
- return VertexAiEmbeddigConnectionDetails.builder()
+ public VertexAiEmbeddingConnectionDetails connectionDetails() {
+ return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
@@ -222,7 +222,7 @@ public VertexAiEmbeddigConnectionDetails connectionDetails() {
@Bean
public VertexAiMultimodalEmbeddingModel vertexAiEmbeddingModel(
- VertexAiEmbeddigConnectionDetails connectionDetails) {
+ VertexAiEmbeddingConnectionDetails connectionDetails) {
VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder()
.withModel(VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001)
diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java
new file mode 100644
index 00000000000..3ba3d878a0a
--- /dev/null
+++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java
@@ -0,0 +1,57 @@
+package org.springframework.ai.vertexai.embedding.text;
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictRequest;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
+import org.springframework.retry.support.RetryTemplate;
+
+import java.io.IOException;
+
+public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel {
+
+ private PredictionServiceClient mockPredictionServiceClient;
+
+ private PredictRequest.Builder mockPredictRequestBuilder;
+
+ public TestVertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
+ VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
+ super(connectionDetails, defaultEmbeddingOptions, retryTemplate);
+ }
+
+ public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) {
+ this.mockPredictionServiceClient = mockPredictionServiceClient;
+ }
+
+ @Override
+ PredictionServiceClient createPredictionServiceClient() {
+ if (mockPredictionServiceClient != null) {
+ return mockPredictionServiceClient;
+ }
+ return super.createPredictionServiceClient();
+ }
+
+ @Override
+ PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
+ if (mockPredictionServiceClient != null) {
+ return mockPredictionServiceClient.predict(predictRequestBuilder.build());
+ }
+ return super.getPredictResponse(client, predictRequestBuilder);
+ }
+
+ public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) {
+ this.mockPredictRequestBuilder = mockPredictRequestBuilder;
+ }
+
+ @Override
+ protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
+ VertexAiTextEmbeddingOptions finalOptions) {
+ if (mockPredictRequestBuilder != null) {
+ return mockPredictRequestBuilder;
+ }
+ return super.getPredictRequestBuilder(request, endpointName, finalOptions);
+ }
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java
index 4c9a9cdcb0f..f98c96b1baa 100644
--- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java
+++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java
@@ -24,7 +24,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
-import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -67,15 +67,15 @@ void defaultEmbedding(String modelName) {
static class Config {
@Bean
- public VertexAiEmbeddigConnectionDetails connectionDetails() {
- return VertexAiEmbeddigConnectionDetails.builder()
+ public VertexAiEmbeddingConnectionDetails connectionDetails() {
+ return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
}
@Bean
- public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails) {
+ public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails) {
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java
new file mode 100644
index 00000000000..2f8658c1d14
--- /dev/null
+++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java
@@ -0,0 +1,143 @@
+package org.springframework.ai.vertexai.embedding.text;
+
+import com.google.cloud.aiplatform.v1.PredictRequest;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
+import com.google.protobuf.Struct;
+import com.google.protobuf.Value;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
+
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class VertexAiTextEmbeddingRetryTests {
+
+ private static class TestRetryListener implements RetryListener {
+
+ int onErrorRetryCount = 0;
+
+ int onSuccessRetryCount = 0;
+
+ @Override
+ public void onSuccess(RetryContext context, RetryCallback callback, T result) {
+ onSuccessRetryCount = context.getRetryCount();
+ }
+
+ @Override
+ public void onError(RetryContext context, RetryCallback callback,
+ Throwable throwable) {
+ onErrorRetryCount = context.getRetryCount();
+ }
+
+ }
+
+ private TestRetryListener retryListener;
+
+ private RetryTemplate retryTemplate;
+
+ @Mock
+ private PredictionServiceClient mockPredictionServiceClient;
+
+ @Mock
+ private VertexAiEmbeddingConnectionDetails mockConnectionDetails;
+
+ @Mock
+ private PredictRequest.Builder mockPredictRequestBuilder;
+
+ @Mock
+ private PredictionServiceSettings mockPredictionServiceSettings;
+
+ private TestVertexAiTextEmbeddingModel embeddingModel;
+
+ @BeforeEach
+ public void setUp() {
+ retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+ retryListener = new TestRetryListener();
+ retryTemplate.registerListener(retryListener);
+
+ embeddingModel = new TestVertexAiTextEmbeddingModel(mockConnectionDetails,
+ VertexAiTextEmbeddingOptions.builder().build(), retryTemplate);
+ embeddingModel.setMockPredictionServiceClient(mockPredictionServiceClient);
+ embeddingModel.setMockPredictRequestBuilder(mockPredictRequestBuilder);
+ when(mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance());
+ }
+
+ @Test
+ public void vertexAiEmbeddingTransientError() {
+ // Setup the mock PredictResponse
+ PredictResponse mockResponse = PredictResponse.newBuilder()
+ .addPredictions(Value.newBuilder()
+ .setStructValue(Struct.newBuilder()
+ .putFields("embeddings", Value.newBuilder()
+ .setStructValue(Struct.newBuilder()
+ .putFields("values",
+ Value.newBuilder()
+ .setListValue(com.google.protobuf.ListValue.newBuilder()
+ .addValues(Value.newBuilder().setNumberValue(9.9))
+ .addValues(Value.newBuilder().setNumberValue(8.8))
+ .build())
+ .build())
+ .putFields("statistics",
+ Value.newBuilder()
+ .setStructValue(Struct.newBuilder()
+ .putFields("token_count", Value.newBuilder().setNumberValue(10).build())
+ .build())
+ .build())
+ .build())
+ .build())
+ .build())
+ .build())
+ .build();
+
+ // Setup the mock PredictionServiceClient
+ when(mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(mockResponse);
+
+ EmbeddingResponse result = embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResults()).hasSize(1);
+ assertThat(result.getResults().get(0).getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+
+ verify(mockPredictRequestBuilder, times(3)).build();
+ }
+
+ @Test
+ public void vertexAiEmbeddingNonTransientError() {
+ // Setup the mock PredictionServiceClient to throw a non-transient error
+ when(mockPredictionServiceClient.predict(any()))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+
+ // Assert that a RuntimeException is thrown and not retried
+ assertThrows(RuntimeException.class, () -> embeddingModel
+ .call(new EmbeddingRequest(List.of("text1", "text2"), null)));
+
+ // Verify that predict was called only once (no retries for non-transient errors)
+ verify(mockPredictionServiceClient, times(1)).predict(any());
+ }
+
+}
diff --git a/models/spring-ai-vertex-ai-gemini/pom.xml b/models/spring-ai-vertex-ai-gemini/pom.xml
index 3ebd175128f..c903c8dc0be 100644
--- a/models/spring-ai-vertex-ai-gemini/pom.xml
+++ b/models/spring-ai-vertex-ai-gemini/pom.xml
@@ -52,6 +52,12 @@
${project.parent.version}
+
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
+
+
org.springframework
spring-web
diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
index b7ea5051d3d..ece66fc31c3 100644
--- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
+++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
@@ -15,14 +15,16 @@
*/
package org.springframework.ai.vertexai.gemini;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.google.cloud.vertexai.VertexAI;
+import com.google.cloud.vertexai.api.*;
+import com.google.cloud.vertexai.api.Candidate.FinishReason;
+import com.google.cloud.vertexai.generativeai.GenerativeModel;
+import com.google.cloud.vertexai.generativeai.PartMaker;
+import com.google.cloud.vertexai.generativeai.ResponseStream;
+import com.google.protobuf.Struct;
+import com.google.protobuf.util.JsonFormat;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
@@ -42,35 +44,23 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.NonNull;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
-
-import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonInclude.Include;
-import com.google.cloud.vertexai.VertexAI;
-import com.google.cloud.vertexai.api.Candidate;
-import com.google.cloud.vertexai.api.Candidate.FinishReason;
-import com.google.cloud.vertexai.api.Content;
-import com.google.cloud.vertexai.api.FunctionCall;
-import com.google.cloud.vertexai.api.FunctionDeclaration;
-import com.google.cloud.vertexai.api.FunctionResponse;
-import com.google.cloud.vertexai.api.GenerateContentResponse;
-import com.google.cloud.vertexai.api.GenerationConfig;
-import com.google.cloud.vertexai.api.Part;
-import com.google.cloud.vertexai.api.Schema;
-import com.google.cloud.vertexai.api.Tool;
-import com.google.cloud.vertexai.generativeai.GenerativeModel;
-import com.google.cloud.vertexai.generativeai.PartMaker;
-import com.google.cloud.vertexai.generativeai.ResponseStream;
-import com.google.protobuf.Struct;
-import com.google.protobuf.util.JsonFormat;
-
import reactor.core.publisher.Flux;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
/**
* @author Christian Tzolov
* @author Grogdunn
@@ -86,6 +76,11 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements
private final VertexAiGeminiChatOptions defaultOptions;
+ /**
+ * The retry template used to retry the API calls.
+ */
+ private final RetryTemplate retryTemplate;
+
private final GenerationConfig generationConfig;
public enum GeminiMessageType {
@@ -152,43 +147,52 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions opti
public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options,
FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) {
+ this(vertexAI, options, functionCallbackContext, toolFunctionCallbacks, RetryUtils.DEFAULT_RETRY_TEMPLATE);
+ }
+
+ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options,
+ FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks,
+ RetryTemplate retryTemplate) {
super(functionCallbackContext, options, toolFunctionCallbacks);
Assert.notNull(vertexAI, "VertexAI must not be null");
Assert.notNull(options, "VertexAiGeminiChatOptions must not be null");
Assert.notNull(options.getModel(), "VertexAiGeminiChatOptions.modelName must not be null");
+ Assert.notNull(retryTemplate, "RetryTemplate must not be null");
this.vertexAI = vertexAI;
this.defaultOptions = options;
this.generationConfig = toGenerationConfig(options);
+ this.retryTemplate = retryTemplate;
}
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
@Override
public ChatResponse call(Prompt prompt) {
+ return retryTemplate.execute(context -> {
+ var geminiRequest = createGeminiRequest(prompt);
- var geminiRequest = createGeminiRequest(prompt);
+ GenerateContentResponse response = this.getContentResponse(geminiRequest);
- GenerateContentResponse response = this.getContentResponse(geminiRequest);
-
- List generations = response.getCandidatesList()
- .stream()
- .map(this::responseCandiateToGeneration)
- .flatMap(List::stream)
- .toList();
+ List generations = response.getCandidatesList()
+ .stream()
+ .map(this::responseCandiateToGeneration)
+ .flatMap(List::stream)
+ .toList();
- ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response));
+ ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response));
- if (!isProxyToolCalls(prompt, this.defaultOptions)
- && isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) {
- var toolCallConversation = handleToolCalls(prompt, chatResponse);
- // Recursively call the call method with the tool call message
- // conversation that contains the call responses.
- return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
- }
+ if (!isProxyToolCalls(prompt, this.defaultOptions)
+ && isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) {
+ var toolCallConversation = handleToolCalls(prompt, chatResponse);
+ // Recursively call the call method with the tool call message
+ // conversation that contains the call responses.
+ return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+ }
- return chatResponse;
+ return chatResponse;
+ });
}
@Override
@@ -525,7 +529,14 @@ private static Schema jsonToSchema(String json) {
}
}
- private GenerateContentResponse getContentResponse(GeminiRequest request) {
+ /**
+ * Generates the content response based on the provided Gemini request. Package
+ * protected for testing purposes.
+ * @param request the GeminiRequest containing the content and model information
+ * @return a GenerateContentResponse containing the generated content
+ * @throws RuntimeException if content generation fails
+ */
+ GenerateContentResponse getContentResponse(GeminiRequest request) {
try {
return request.model.generateContent(request.contents);
}
diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java
new file mode 100644
index 00000000000..5c3fcffdbbb
--- /dev/null
+++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java
@@ -0,0 +1,45 @@
+package org.springframework.ai.vertexai.gemini;
+
+import com.google.cloud.vertexai.VertexAI;
+import com.google.cloud.vertexai.api.GenerateContentResponse;
+import com.google.cloud.vertexai.generativeai.GenerativeModel;
+import org.springframework.ai.model.function.FunctionCallback;
+import org.springframework.ai.model.function.FunctionCallbackContext;
+import org.springframework.retry.support.RetryTemplate;
+
+import java.io.IOException;
+import java.util.List;
+
+public class TestVertexAiGeminiChatModel extends VertexAiGeminiChatModel {
+
+ private GenerativeModel mockGenerativeModel;
+
+ public TestVertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options,
+ FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks,
+ RetryTemplate retryTemplate) {
+ super(vertexAI, options, functionCallbackContext, toolFunctionCallbacks, retryTemplate);
+ }
+
+ @Override
+ GenerateContentResponse getContentResponse(GeminiRequest request) {
+ if (mockGenerativeModel != null) {
+ try {
+ return mockGenerativeModel.generateContent(request.contents());
+ }
+ catch (IOException e) {
+ // Should not be thrown by testing class
+ throw new RuntimeException("Failed to generate content", e);
+ }
+ catch (RuntimeException e) {
+ // Re-throw RuntimeExceptions (including TransientAiException) as is
+ throw e;
+ }
+ }
+ return super.getContentResponse(request);
+ }
+
+ public void setMockGenerativeModel(GenerativeModel mockGenerativeModel) {
+ this.mockGenerativeModel = mockGenerativeModel;
+ }
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java
new file mode 100644
index 00000000000..924342384a3
--- /dev/null
+++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java
@@ -0,0 +1,120 @@
+package org.springframework.ai.vertexai.gemini;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.*;
+
+import com.google.cloud.vertexai.VertexAI;
+import com.google.cloud.vertexai.api.Candidate;
+import com.google.cloud.vertexai.api.Content;
+import com.google.cloud.vertexai.api.GenerateContentResponse;
+import com.google.cloud.vertexai.api.Part;
+import com.google.cloud.vertexai.generativeai.GenerativeModel;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@SuppressWarnings("unchecked")
+@ExtendWith(MockitoExtension.class)
+public class VertexAiGeminiRetryTests {
+
+ private static class TestRetryListener implements RetryListener {
+
+ int onErrorRetryCount = 0;
+
+ int onSuccessRetryCount = 0;
+
+ @Override
+ public void onSuccess(RetryContext context, RetryCallback callback, T result) {
+ onSuccessRetryCount = context.getRetryCount();
+ }
+
+ @Override
+ public void onError(RetryContext context, RetryCallback callback,
+ Throwable throwable) {
+ onErrorRetryCount = context.getRetryCount();
+ }
+
+ }
+
+ private TestRetryListener retryListener;
+
+ private RetryTemplate retryTemplate;
+
+ @Mock
+ private VertexAI vertexAI;
+
+ @Mock
+ private GenerativeModel mockGenerativeModel;
+
+ private TestVertexAiGeminiChatModel chatModel;
+
+ @BeforeEach
+ public void setUp() {
+ retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+ retryListener = new TestRetryListener();
+ retryTemplate.registerListener(retryListener);
+
+ chatModel = new TestVertexAiGeminiChatModel(vertexAI,
+ VertexAiGeminiChatOptions.builder()
+ .withTemperature(0.7)
+ .withTopP(1.0)
+ .withModel(VertexAiGeminiChatModel.ChatModel.GEMINI_PRO.getValue())
+ .build(),
+ null, Collections.emptyList(), retryTemplate);
+
+ chatModel.setMockGenerativeModel(mockGenerativeModel);
+ }
+
+ @Test
+ public void vertexAiGeminiChatTransientError() throws IOException {
+ // Create a mocked successful response
+ GenerateContentResponse mockedResponse = GenerateContentResponse.newBuilder()
+ .addCandidates(Candidate.newBuilder()
+ .setContent(Content.newBuilder().addParts(Part.newBuilder().setText("Response").build()).build())
+ .build())
+ .build();
+
+ when(mockGenerativeModel.generateContent(any(List.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(mockedResponse);
+
+ // Call the chat model
+ ChatResponse result = chatModel.call(new Prompt("test prompt"));
+
+ // Assertions
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput().getContent()).isEqualTo("Response");
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void vertexAiGeminiChatNonTransientError() throws Exception {
+ // Set up the mock GenerativeModel to throw a non-transient RuntimeException
+ when(mockGenerativeModel.generateContent(any(List.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+
+ // Assert that a RuntimeException is thrown when calling the chat model
+ assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("test prompt")));
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java
index 16bde0c6ce6..668975879f3 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java
@@ -17,14 +17,18 @@
import java.io.IOException;
-import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@@ -36,20 +40,22 @@
* @author Christian Tzolov
* @since 1.0.0
*/
+@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class })
@ConditionalOnClass({ VertexAI.class, VertexAiTextEmbeddingModel.class })
@EnableConfigurationProperties({ VertexAiEmbeddingConnectionProperties.class, VertexAiTextEmbeddingProperties.class,
- VertexAiMultimodalEmbeddingProperties.class, })
+ VertexAiMultimodalEmbeddingProperties.class })
+@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class })
public class VertexAiEmbeddingAutoConfiguration {
@Bean
@ConditionalOnMissingBean
- public VertexAiEmbeddigConnectionDetails connectionDetails(
+ public VertexAiEmbeddingConnectionDetails connectionDetails(
VertexAiEmbeddingConnectionProperties connectionProperties) {
Assert.hasText(connectionProperties.getProjectId(), "Vertex AI project-id must be set!");
Assert.hasText(connectionProperties.getLocation(), "Vertex AI location must be set!");
- var connectionBuilder = VertexAiEmbeddigConnectionDetails.builder()
+ var connectionBuilder = VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(connectionProperties.getProjectId())
.withLocation(connectionProperties.getLocation());
@@ -65,17 +71,17 @@ public VertexAiEmbeddigConnectionDetails connectionDetails(
@ConditionalOnMissingBean
@ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
- public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddigConnectionDetails connectionDetails,
- VertexAiTextEmbeddingProperties textEmbeddingProperties) throws IOException {
+ public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails,
+ VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate) {
- return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions());
+ return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), retryTemplate);
}
@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(prefix = VertexAiMultimodalEmbeddingProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
- public VertexAiMultimodalEmbeddingModel multimodalEmbedding(VertexAiEmbeddigConnectionDetails connectionDetails,
+ public VertexAiMultimodalEmbeddingModel multimodalEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiMultimodalEmbeddingProperties multimodalEmbeddingProperties) throws IOException {
return new VertexAiMultimodalEmbeddingModel(connectionDetails, multimodalEmbeddingProperties.getOptions());
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java
index 6003594f7bf..17dc2121a87 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java
@@ -18,16 +18,20 @@
import java.io.IOException;
import java.util.List;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@@ -40,10 +44,12 @@
*
* @author Christian Tzolov
* @author Soby Chacko
- * @since 0.8.0
+ * @since 1.0.0
*/
+@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class })
@ConditionalOnClass({ VertexAI.class, VertexAiGeminiChatModel.class })
@EnableConfigurationProperties({ VertexAiGeminiChatProperties.class, VertexAiGeminiConnectionProperties.class })
+@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class })
public class VertexAiGeminiAutoConfiguration {
@Bean
@@ -79,12 +85,12 @@ public VertexAI vertexAi(VertexAiGeminiConnectionProperties connectionProperties
@ConditionalOnProperty(prefix = VertexAiGeminiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGeminiChatProperties chatProperties,
- List toolFunctionCallbacks, ApplicationContext context) {
+ List toolFunctionCallbacks, ApplicationContext context, RetryTemplate retryTemplate) {
FunctionCallbackContext functionCallbackContext = springAiFunctionManager(context);
return new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), functionCallbackContext,
- toolFunctionCallbacks);
+ toolFunctionCallbacks, retryTemplate);
}
/**