Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions models/spring-ai-vertex-ai-embedding/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@
<artifactId>spring-boot-starter-logging</artifactId>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
*/
package org.springframework.ai.vertexai.embedding.text;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
Expand All @@ -28,7 +30,12 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
Expand All @@ -39,12 +46,13 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;

import io.micrometer.observation.ObservationRegistry;

/**
* A class representing a Vertex AI Text Embedding Model.
Expand All @@ -55,24 +63,44 @@
*/
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

public final VertexAiTextEmbeddingOptions defaultOptions;

private final VertexAiEmbeddingConnectionDetails connectionDetails;

private final RetryTemplate retryTemplate;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
this(connectionDetails, defaultEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
}

public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
this.connectionDetails = connectionDetails;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
Expand All @@ -83,42 +111,64 @@ public float[] embed(Document document) {

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
return retryTemplate.execute(context -> {
VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;

if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
VertexAiTextEmbeddingOptions.class);
}
final VertexAiTextEmbeddingOptions finalOptions = mergedOptions(request);

PredictionServiceClient client = createPredictionServiceClient();
var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(AiProvider.VERTEX_AI.value())
.requestOptions(finalOptions)
.build();

EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
PredictionServiceClient client = createPredictionServiceClient();

PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
finalOptions);
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());

PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
finalOptions);

int index = 0;
int totalTokenCount = 0;
List<Embedding> embeddingList = new ArrayList<>();
for (Value prediction : embeddingResponse.getPredictionsList()) {
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
PredictResponse embeddingResponse = retryTemplate
.execute(context -> getPredictResponse(client, predictRequestBuilder));

Value values = embeddings.getStructValue().getFieldsOrThrow("values");
int index = 0;
int totalTokenCount = 0;
List<Embedding> embeddingList = new ArrayList<>();
for (Value prediction : embeddingResponse.getPredictionsList()) {
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();

float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
Value values = embeddings.getStructValue().getFieldsOrThrow("values");

embeddingList.add(new Embedding(vectorValues, index++));
}
return new EmbeddingResponse(embeddingList,
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
});
float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);

embeddingList.add(new Embedding(vectorValues, index++));
}
EmbeddingResponse response = new EmbeddingResponse(embeddingList,
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));

observationContext.setResponse(response);

return response;
});
}

private VertexAiTextEmbeddingOptions mergedOptions(EmbeddingRequest request) {

VertexAiTextEmbeddingOptions mergedOptions = this.defaultOptions;

if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
mergedOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
VertexAiTextEmbeddingOptions.class);
}

return mergedOptions;
}

protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
Expand Down Expand Up @@ -183,4 +233,13 @@ public int dimensions() {
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
VertexAiTextEmbeddingModelName::getDimensions));

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
*/
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vertexai.embedding.text;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;

/**
* Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}.
*
* @author Christian Tzolov
*/
@SpringBootTest(classes = VertexAiTextEmbeddingModelObservationIT.Config.class)
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
public class VertexAiTextEmbeddingModelObservationIT {

@Autowired
TestObservationRegistry observationRegistry;

@Autowired
VertexAiTextEmbeddingModel embeddingModel;

@Test
void observationForEmbeddingOperation() {

var options = VertexAiTextEmbeddingOptions.builder()
.withModel(VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
.withDimensions(768)
.build();

EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);

EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
assertThat(embeddingResponse.getResults()).isNotEmpty();

EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
assertThat(responseMetadata).isNotNull();

TestObservationRegistryAssert.assertThat(observationRegistry)
.doesNotHaveAnyRemainingCurrentObservation()
.hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
.that()
.hasContextualNameEqualTo("embedding " + VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
AiOperationType.EMBEDDING.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.VERTEX_AI.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), "768")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
.hasBeenStarted()
.hasBeenStopped();
}

@SpringBootConfiguration
static class Config {

@Bean
public TestObservationRegistry observationRegistry() {
return TestObservationRegistry.create();
}

@Bean
public VertexAiEmbeddingConnectionDetails connectionDetails() {
return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
}

@Bean
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
ObservationRegistry observationRegistry) {

VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
.build();

return new VertexAiTextEmbeddingModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE,
observationRegistry);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import java.io.IOException;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
Expand All @@ -34,6 +36,8 @@

import com.google.cloud.vertexai.VertexAI;

import io.micrometer.observation.ObservationRegistry;

/**
* Auto-configuration for Vertex AI Gemini Chat.
*
Expand Down Expand Up @@ -73,9 +77,16 @@ public VertexAiEmbeddingConnectionDetails connectionDetails(
@ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate) {
VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {

var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(),
retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));

observationConvention.ifAvailable(embeddingModel::setObservationConvention);

return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), retryTemplate);
return embeddingModel;
}

@Bean
Expand Down