From 6449bf6497604e49349d46c9f214ec9d5f20a3fd Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 8 Oct 2024 10:57:31 +0200 Subject: [PATCH] Observability instrumentation for VertexAI Text Embedding --- models/spring-ai-vertex-ai-embedding/pom.xml | 6 + .../text/VertexAiTextEmbeddingModel.java | 135 +++++++++++++----- ...rtexAiTextEmbeddingModelObservationIT.java | 125 ++++++++++++++++ .../VertexAiEmbeddingAutoConfiguration.java | 15 +- 4 files changed, 241 insertions(+), 40 deletions(-) create mode 100644 models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml index 08526534078..5988fcf4476 100644 --- a/models/spring-ai-vertex-ai-embedding/pom.xml +++ b/models/spring-ai-vertex-ai-embedding/pom.xml @@ -76,6 +76,12 @@ spring-boot-starter-logging + + io.micrometer + micrometer-observation-test + test + + org.springframework.ai diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 40920bc4c6c..26da5fea0c0 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -15,11 +15,13 @@ */ package org.springframework.ai.vertexai.embedding.text; -import com.google.cloud.aiplatform.v1.EndpointName; -import com.google.cloud.aiplatform.v1.PredictRequest; -import com.google.cloud.aiplatform.v1.PredictResponse; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; -import com.google.protobuf.Value; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -28,7 +30,12 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; @@ -39,12 +46,13 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Value; + +import io.micrometer.observation.ObservationRegistry; /** * A class representing a Vertex AI Text Embedding Model. @@ -55,12 +63,24 @@ */ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + public final VertexAiTextEmbeddingOptions defaultOptions; private final VertexAiEmbeddingConnectionDetails connectionDetails; private final RetryTemplate retryTemplate; + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiTextEmbeddingOptions defaultEmbeddingOptions) { this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); @@ -68,11 +88,19 @@ public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionD public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) { + this(connectionDetails, defaultEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP); + } + + public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, + VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null"); Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); this.defaultOptions = defaultEmbeddingOptions.initializeDefaults(); this.connectionDetails = connectionDetails; this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; } @Override @@ -83,42 +111,64 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { - return retryTemplate.execute(context -> { - VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions; - if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) { - var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(); - finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, - VertexAiTextEmbeddingOptions.class); - } + final VertexAiTextEmbeddingOptions finalOptions = mergedOptions(request); - PredictionServiceClient client = createPredictionServiceClient(); + var observationContext = EmbeddingModelObservationContext.builder() + .embeddingRequest(request) + .provider(AiProvider.VERTEX_AI.value()) + .requestOptions(finalOptions) + .build(); - EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); + return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + PredictionServiceClient client = createPredictionServiceClient(); - PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, - finalOptions); + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); - PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder); + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, + finalOptions); - int index = 0; - int totalTokenCount = 0; - List embeddingList = new ArrayList<>(); - for (Value prediction : embeddingResponse.getPredictionsList()) { - Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); - Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); - Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); - totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); + PredictResponse embeddingResponse = retryTemplate + .execute(context -> getPredictResponse(client, predictRequestBuilder)); - Value values = embeddings.getStructValue().getFieldsOrThrow("values"); + int index = 0; + int totalTokenCount = 0; + List embeddingList = new ArrayList<>(); + for (Value prediction : embeddingResponse.getPredictionsList()) { + Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); + Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); + Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); + totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); - float[] vectorValues = VertexAiEmbeddingUtils.toVector(values); + Value values = embeddings.getStructValue().getFieldsOrThrow("values"); - embeddingList.add(new Embedding(vectorValues, index++)); - } - return new EmbeddingResponse(embeddingList, - generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); - }); + float[] vectorValues = VertexAiEmbeddingUtils.toVector(values); + + embeddingList.add(new Embedding(vectorValues, index++)); + } + EmbeddingResponse response = new EmbeddingResponse(embeddingList, + generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); + + observationContext.setResponse(response); + + return response; + }); + } + + private VertexAiTextEmbeddingOptions mergedOptions(EmbeddingRequest request) { + + VertexAiTextEmbeddingOptions mergedOptions = this.defaultOptions; + + if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) { + var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(); + mergedOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, + VertexAiTextEmbeddingOptions.class); + } + + return mergedOptions; } protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, @@ -183,4 +233,13 @@ public int dimensions() { .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, VertexAiTextEmbeddingModelName::getDimensions)); + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + } \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java new file mode 100644 index 00000000000..f6ac7c5b531 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java @@ -0,0 +1,125 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.text; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +/** + * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}. + * + * @author Christian Tzolov + */ +@SpringBootTest(classes = VertexAiTextEmbeddingModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +public class VertexAiTextEmbeddingModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + VertexAiTextEmbeddingModel embeddingModel; + + @Test + void observationForEmbeddingOperation() { + + var options = VertexAiTextEmbeddingOptions.builder() + .withModel(VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName()) + .withDimensions(768) + .build(); + + EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); + + EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).isNotEmpty(); + + EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("embedding " + VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.EMBEDDING.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.VERTEX_AI.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), + VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), "768") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public VertexAiEmbeddingConnectionDetails connectionDetails() { + return VertexAiEmbeddingConnectionDetails.builder() + .withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")) + .withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION")) + .build(); + } + + @Bean + public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, + ObservationRegistry observationRegistry) { + + VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() + .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) + .build(); + + return new VertexAiTextEmbeddingModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE, + observationRegistry); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java index de18b8cf9b6..b51c0e718cc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java @@ -18,9 +18,11 @@ import java.io.IOException; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel; import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -34,6 +36,8 @@ import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + /** * Auto-configuration for Vertex AI Gemini Chat. * @@ -73,9 +77,16 @@ public VertexAiEmbeddingConnectionDetails connectionDetails( @ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails, - VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate) { + VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate, + ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), + retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + + observationConvention.ifAvailable(embeddingModel::setObservationConvention); - return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), retryTemplate); + return embeddingModel; } @Bean