Skip to content

Commit 6449bf6

Browse files
committed
Observability instrumentation for VertexAI Text Embedding
1 parent 82a69dc commit 6449bf6

File tree

4 files changed

+241
-40
lines changed

4 files changed

+241
-40
lines changed

models/spring-ai-vertex-ai-embedding/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@
7676
<artifactId>spring-boot-starter-logging</artifactId>
7777
</dependency>
7878

79+
<dependency>
80+
<groupId>io.micrometer</groupId>
81+
<artifactId>micrometer-observation-test</artifactId>
82+
<scope>test</scope>
83+
</dependency>
84+
7985
<!-- test dependencies -->
8086
<dependency>
8187
<groupId>org.springframework.ai</groupId>

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java

Lines changed: 97 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
*/
1616
package org.springframework.ai.vertexai.embedding.text;
1717

18-
import com.google.cloud.aiplatform.v1.EndpointName;
19-
import com.google.cloud.aiplatform.v1.PredictRequest;
20-
import com.google.cloud.aiplatform.v1.PredictResponse;
21-
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
22-
import com.google.protobuf.Value;
18+
import java.io.IOException;
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.stream.Collectors;
23+
import java.util.stream.Stream;
24+
2325
import org.springframework.ai.chat.metadata.Usage;
2426
import org.springframework.ai.document.Document;
2527
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -28,7 +30,12 @@
2830
import org.springframework.ai.embedding.EmbeddingRequest;
2931
import org.springframework.ai.embedding.EmbeddingResponse;
3032
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
33+
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
34+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
35+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
36+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
3137
import org.springframework.ai.model.ModelOptionsUtils;
38+
import org.springframework.ai.observation.conventions.AiProvider;
3239
import org.springframework.ai.retry.RetryUtils;
3340
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
3441
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
@@ -39,12 +46,13 @@
3946
import org.springframework.util.Assert;
4047
import org.springframework.util.StringUtils;
4148

42-
import java.io.IOException;
43-
import java.util.ArrayList;
44-
import java.util.List;
45-
import java.util.Map;
46-
import java.util.stream.Collectors;
47-
import java.util.stream.Stream;
49+
import com.google.cloud.aiplatform.v1.EndpointName;
50+
import com.google.cloud.aiplatform.v1.PredictRequest;
51+
import com.google.cloud.aiplatform.v1.PredictResponse;
52+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
53+
import com.google.protobuf.Value;
54+
55+
import io.micrometer.observation.ObservationRegistry;
4856

4957
/**
5058
* A class representing a Vertex AI Text Embedding Model.
@@ -55,24 +63,44 @@
5563
*/
5664
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
5765

66+
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
67+
5868
public final VertexAiTextEmbeddingOptions defaultOptions;
5969

6070
private final VertexAiEmbeddingConnectionDetails connectionDetails;
6171

6272
private final RetryTemplate retryTemplate;
6373

74+
/**
75+
* Observation registry used for instrumentation.
76+
*/
77+
private final ObservationRegistry observationRegistry;
78+
79+
/**
80+
* Conventions to use for generating observations.
81+
*/
82+
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
83+
6484
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
6585
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
6686
this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
6787
}
6888

6989
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
7090
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
91+
this(connectionDetails, defaultEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
92+
}
93+
94+
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
95+
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate,
96+
ObservationRegistry observationRegistry) {
7197
Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
7298
Assert.notNull(retryTemplate, "retryTemplate must not be null");
99+
Assert.notNull(observationRegistry, "observationRegistry must not be null");
73100
this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
74101
this.connectionDetails = connectionDetails;
75102
this.retryTemplate = retryTemplate;
103+
this.observationRegistry = observationRegistry;
76104
}
77105

78106
@Override
@@ -83,42 +111,64 @@ public float[] embed(Document document) {
83111

84112
@Override
85113
public EmbeddingResponse call(EmbeddingRequest request) {
86-
return retryTemplate.execute(context -> {
87-
VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;
88114

89-
if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
90-
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
91-
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
92-
VertexAiTextEmbeddingOptions.class);
93-
}
115+
final VertexAiTextEmbeddingOptions finalOptions = mergedOptions(request);
94116

95-
PredictionServiceClient client = createPredictionServiceClient();
117+
var observationContext = EmbeddingModelObservationContext.builder()
118+
.embeddingRequest(request)
119+
.provider(AiProvider.VERTEX_AI.value())
120+
.requestOptions(finalOptions)
121+
.build();
96122

97-
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
123+
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
124+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
125+
this.observationRegistry)
126+
.observe(() -> {
127+
PredictionServiceClient client = createPredictionServiceClient();
98128

99-
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
100-
finalOptions);
129+
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
101130

102-
PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);
131+
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
132+
finalOptions);
103133

104-
int index = 0;
105-
int totalTokenCount = 0;
106-
List<Embedding> embeddingList = new ArrayList<>();
107-
for (Value prediction : embeddingResponse.getPredictionsList()) {
108-
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
109-
Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
110-
Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
111-
totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
134+
PredictResponse embeddingResponse = retryTemplate
135+
.execute(context -> getPredictResponse(client, predictRequestBuilder));
112136

113-
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
137+
int index = 0;
138+
int totalTokenCount = 0;
139+
List<Embedding> embeddingList = new ArrayList<>();
140+
for (Value prediction : embeddingResponse.getPredictionsList()) {
141+
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
142+
Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
143+
Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
144+
totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
114145

115-
float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
146+
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
116147

117-
embeddingList.add(new Embedding(vectorValues, index++));
118-
}
119-
return new EmbeddingResponse(embeddingList,
120-
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
121-
});
148+
float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
149+
150+
embeddingList.add(new Embedding(vectorValues, index++));
151+
}
152+
EmbeddingResponse response = new EmbeddingResponse(embeddingList,
153+
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
154+
155+
observationContext.setResponse(response);
156+
157+
return response;
158+
});
159+
}
160+
161+
private VertexAiTextEmbeddingOptions mergedOptions(EmbeddingRequest request) {
162+
163+
VertexAiTextEmbeddingOptions mergedOptions = this.defaultOptions;
164+
165+
if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
166+
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
167+
mergedOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
168+
VertexAiTextEmbeddingOptions.class);
169+
}
170+
171+
return mergedOptions;
122172
}
123173

124174
protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
@@ -183,4 +233,13 @@ public int dimensions() {
183233
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
184234
VertexAiTextEmbeddingModelName::getDimensions));
185235

236+
/**
237+
* Use the provided convention for reporting observation data.
238+
* @param observationConvention The provided convention
239+
*/
240+
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
241+
Assert.notNull(observationConvention, "observationConvention cannot be null");
242+
this.observationConvention = observationConvention;
243+
}
244+
186245
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.vertexai.embedding.text;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.List;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
24+
import org.springframework.ai.embedding.EmbeddingRequest;
25+
import org.springframework.ai.embedding.EmbeddingResponse;
26+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
27+
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
28+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
29+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
30+
import org.springframework.ai.observation.conventions.AiOperationType;
31+
import org.springframework.ai.observation.conventions.AiProvider;
32+
import org.springframework.ai.retry.RetryUtils;
33+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
34+
import org.springframework.beans.factory.annotation.Autowired;
35+
import org.springframework.boot.SpringBootConfiguration;
36+
import org.springframework.boot.test.context.SpringBootTest;
37+
import org.springframework.context.annotation.Bean;
38+
39+
import io.micrometer.observation.ObservationRegistry;
40+
import io.micrometer.observation.tck.TestObservationRegistry;
41+
import io.micrometer.observation.tck.TestObservationRegistryAssert;
42+
43+
/**
44+
* Integration tests for observation instrumentation in {@link VertexAiTextEmbeddingModel}.
45+
*
46+
* @author Christian Tzolov
47+
*/
48+
@SpringBootTest(classes = VertexAiTextEmbeddingModelObservationIT.Config.class)
49+
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
50+
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
51+
public class VertexAiTextEmbeddingModelObservationIT {
52+
53+
@Autowired
54+
TestObservationRegistry observationRegistry;
55+
56+
@Autowired
57+
VertexAiTextEmbeddingModel embeddingModel;
58+
59+
@Test
60+
void observationForEmbeddingOperation() {
61+
62+
var options = VertexAiTextEmbeddingOptions.builder()
63+
.withModel(VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
64+
.withDimensions(768)
65+
.build();
66+
67+
EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);
68+
69+
EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
70+
assertThat(embeddingResponse.getResults()).isNotEmpty();
71+
72+
EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
73+
assertThat(responseMetadata).isNotNull();
74+
75+
TestObservationRegistryAssert.assertThat(observationRegistry)
76+
.doesNotHaveAnyRemainingCurrentObservation()
77+
.hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
78+
.that()
79+
.hasContextualNameEqualTo("embedding " + VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
80+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
81+
AiOperationType.EMBEDDING.value())
82+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.VERTEX_AI.value())
83+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
84+
VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName())
85+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
86+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), "768")
87+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
88+
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
89+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
90+
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
91+
.hasBeenStarted()
92+
.hasBeenStopped();
93+
}
94+
95+
@SpringBootConfiguration
96+
static class Config {
97+
98+
@Bean
99+
public TestObservationRegistry observationRegistry() {
100+
return TestObservationRegistry.create();
101+
}
102+
103+
@Bean
104+
public VertexAiEmbeddingConnectionDetails connectionDetails() {
105+
return VertexAiEmbeddingConnectionDetails.builder()
106+
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
107+
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
108+
.build();
109+
}
110+
111+
@Bean
112+
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
113+
ObservationRegistry observationRegistry) {
114+
115+
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
116+
.withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
117+
.build();
118+
119+
return new VertexAiTextEmbeddingModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE,
120+
observationRegistry);
121+
}
122+
123+
}
124+
125+
}

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
import java.io.IOException;
1919

2020
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
21+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
2122
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
2223
import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel;
2324
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel;
25+
import org.springframework.beans.factory.ObjectProvider;
2426
import org.springframework.boot.autoconfigure.AutoConfiguration;
2527
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
2628
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -34,6 +36,8 @@
3436

3537
import com.google.cloud.vertexai.VertexAI;
3638

39+
import io.micrometer.observation.ObservationRegistry;
40+
3741
/**
3842
* Auto-configuration for Vertex AI Embedding.
3943
*
@@ -73,9 +77,16 @@ public VertexAiEmbeddingConnectionDetails connectionDetails(
7377
@ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled",
7478
havingValue = "true", matchIfMissing = true)
7579
public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddingConnectionDetails connectionDetails,
76-
VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate) {
80+
VertexAiTextEmbeddingProperties textEmbeddingProperties, RetryTemplate retryTemplate,
81+
ObjectProvider<ObservationRegistry> observationRegistry,
82+
ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
83+
84+
var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(),
85+
retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
86+
87+
observationConvention.ifAvailable(embeddingModel::setObservationConvention);
7788

78-
return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions(), retryTemplate);
89+
return embeddingModel;
7990
}
8091

8192
@Bean

0 commit comments

Comments
 (0)