diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml
index 08526534078..5988fcf4476 100644
--- a/models/spring-ai-vertex-ai-embedding/pom.xml
+++ b/models/spring-ai-vertex-ai-embedding/pom.xml
@@ -76,6 +76,12 @@
spring-boot-starter-logging
+
+ io.micrometer
+ micrometer-observation-test
+ test
+
+
org.springframework.ai
diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java
index 40920bc4c6c..26da5fea0c0 100644
--- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java
+++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java
@@ -15,11 +15,13 @@
*/
package org.springframework.ai.vertexai.embedding.text;
-import com.google.cloud.aiplatform.v1.EndpointName;
-import com.google.cloud.aiplatform.v1.PredictRequest;
-import com.google.cloud.aiplatform.v1.PredictResponse;
-import com.google.cloud.aiplatform.v1.PredictionServiceClient;
-import com.google.protobuf.Value;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -28,7 +30,12 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
@@ -39,12 +46,13 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictRequest;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.protobuf.Value;
+
+import io.micrometer.observation.ObservationRegistry;
/**
* A class representing a Vertex AI Text Embedding Model.
@@ -55,12 +63,24 @@
*/
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
+ private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
+
public final VertexAiTextEmbeddingOptions defaultOptions;
private final VertexAiEmbeddingConnectionDetails connectionDetails;
private final RetryTemplate retryTemplate;
+ /**
+ * Observation registry used for instrumentation.
+ */
+ private final ObservationRegistry observationRegistry;
+
+ /**
+ * Conventions to use for generating observations.
+ */
+ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
+
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
@@ -68,11 +88,19 @@ public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionD
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
+ this(connectionDetails, defaultEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
+ }
+
+ public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
+ VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate,
+ ObservationRegistry observationRegistry) {
Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
+ Assert.notNull(observationRegistry, "observationRegistry must not be null");
this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
this.connectionDetails = connectionDetails;
this.retryTemplate = retryTemplate;
+ this.observationRegistry = observationRegistry;
}
@Override
@@ -83,42 +111,64 @@ public float[] embed(Document document) {
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
- return retryTemplate.execute(context -> {
- VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;
- if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
- var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
- finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
- VertexAiTextEmbeddingOptions.class);
- }
+ final VertexAiTextEmbeddingOptions finalOptions = mergedOptions(request);
- PredictionServiceClient client = createPredictionServiceClient();
+ var observationContext = EmbeddingModelObservationContext.builder()
+ .embeddingRequest(request)
+ .provider(AiProvider.VERTEX_AI.value())
+ .requestOptions(finalOptions)
+ .build();
- EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
+ return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
+ .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry)
+ .observe(() -> {
+ PredictionServiceClient client = createPredictionServiceClient();
- PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
- finalOptions);
+ EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
- PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);
+ PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
+ finalOptions);
- int index = 0;
- int totalTokenCount = 0;
- List embeddingList = new ArrayList<>();
- for (Value prediction : embeddingResponse.getPredictionsList()) {
- Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
- Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
- Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
- totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
+ PredictResponse embeddingResponse = retryTemplate
+ .execute(context -> getPredictResponse(client, predictRequestBuilder));
- Value values = embeddings.getStructValue().getFieldsOrThrow("values");
+ int index = 0;
+ int totalTokenCount = 0;
+ List embeddingList = new ArrayList<>();
+ for (Value prediction : embeddingResponse.getPredictionsList()) {
+ Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
+ Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
+ Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
+ totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
- float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
+ Value values = embeddings.getStructValue().getFieldsOrThrow("values");
- embeddingList.add(new Embedding(vectorValues, index++));
- }
- return new EmbeddingResponse(embeddingList,
- generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
- });
+ float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
+
+ embeddingList.add(new Embedding(vectorValues, index++));
+ }
+ EmbeddingResponse response = new EmbeddingResponse(embeddingList,
+ generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
+
+ observationContext.setResponse(response);
+
+ return response;
+ });
+ }
+
+ private VertexAiTextEmbeddingOptions mergedOptions(EmbeddingRequest request) {
+
+ VertexAiTextEmbeddingOptions mergedOptions = this.defaultOptions;
+
+ if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
+ var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
+ mergedOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
+ VertexAiTextEmbeddingOptions.class);
+ }
+
+ return mergedOptions;
}
protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
@@ -183,4 +233,13 @@ public int dimensions() {
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
VertexAiTextEmbeddingModelName::getDimensions));
+ /**
+ * Use the provided convention for reporting observation data
+ * @param observationConvention The provided convention
+ */
+ public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
+ Assert.notNull(observationConvention, "observationConvention cannot be null");
+ this.observationConvention = observationConvention;
+ }
+
}
\ No newline at end of file
diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java
new file mode 100644
index 00000000000..f6ac7c5b531
--- /dev/null
+++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vertexai.embedding.text;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
+import org.springframework.ai.observation.conventions.AiOperationType;
+import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
+
+import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.tck.TestObservationRegistry;
+import io.micrometer.observation.tck.TestObservationRegistryAssert;
+
+/**
+ * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}.
+ *
+ * @author Christian Tzolov
+ */
+@SpringBootTest(classes = VertexAiTextEmbeddingModelObservationIT.Config.class)
+@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
+@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
+public class VertexAiTextEmbeddingModelObservationIT {
+
+ @Autowired
+ TestObservationRegistry observationRegistry;
+
+ @Autowired
+ VertexAiTextEmbeddingModel embeddingModel;
+
+ @Test
+ void observationForEmbeddingOperation() {
+
+ var options = VertexAiTextEmbeddingOptions.builder()
+ .withModel(VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
+ .withDimensions(768)
+ .build();
+
+ EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);
+
+ EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
+ assertThat(embeddingResponse.getResults()).isNotEmpty();
+
+ EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
+ assertThat(responseMetadata).isNotNull();
+
+ TestObservationRegistryAssert.assertThat(observationRegistry)
+ .doesNotHaveAnyRemainingCurrentObservation()
+ .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
+ .that()
+ .hasContextualNameEqualTo("embedding " + VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
+ AiOperationType.EMBEDDING.value())
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.VERTEX_AI.value())
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
+ VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
+ .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), "768")
+ .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getPromptTokens()))
+ .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getTotalTokens()))
+ .hasBeenStarted()
+ .hasBeenStopped();
+ }
+
+ @SpringBootConfiguration
+ static class Config {
+
+ @Bean
+ public TestObservationRegistry observationRegistry() {
+ return TestObservationRegistry.create();
+ }
+
+ @Bean
+ public VertexAiEmbeddingConnectionDetails connectionDetails() {
+ return VertexAiEmbeddingConnectionDetails.builder()
+ .withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
+ .withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
+ .build();
+ }
+
+ @Bean
+ public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
+ ObservationRegistry observationRegistry) {
+
+ VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
+ .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
+ .build();
+
+ return new VertexAiTextEmbeddingModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE,
+ observationRegistry);
+ }
+
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java
index de18b8cf9b6..b51c0e718cc 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java
@@ -18,9 +18,11 @@
import java.io.IOException;
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel;
+import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -34,6 +36,8 @@
import com.google.cloud.vertexai.VertexAI;
+import io.micrometer.observation.ObservationRegistry;
+
/**
* Auto-configuration for Vertex AI Gemini Chat.
*
@@ -73,9 +77,16 @@ public VertexAiEmbeddingConnectionDetails connectionDetails(
@ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails,
- VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate) {
+ VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate,
+ ObjectProvider observationRegistry,
+ ObjectProvider observationConvention) {
+
+ var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(),
+ retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
+
+ observationConvention.ifAvailable(embeddingModel::setObservationConvention);
- return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), retryTemplate);
+ return embeddingModel;
}
@Bean