1515 */
1616package org .springframework .ai .vertexai .embedding .text ;
1717
18- import com .google .cloud .aiplatform .v1 .EndpointName ;
19- import com .google .cloud .aiplatform .v1 .PredictRequest ;
20- import com .google .cloud .aiplatform .v1 .PredictResponse ;
21- import com .google .cloud .aiplatform .v1 .PredictionServiceClient ;
22- import com .google .protobuf .Value ;
18+ import java .io .IOException ;
19+ import java .util .ArrayList ;
20+ import java .util .List ;
21+ import java .util .Map ;
22+ import java .util .stream .Collectors ;
23+ import java .util .stream .Stream ;
24+
2325import org .springframework .ai .chat .metadata .Usage ;
2426import org .springframework .ai .document .Document ;
2527import org .springframework .ai .embedding .AbstractEmbeddingModel ;
2830import org .springframework .ai .embedding .EmbeddingRequest ;
2931import org .springframework .ai .embedding .EmbeddingResponse ;
3032import org .springframework .ai .embedding .EmbeddingResponseMetadata ;
33+ import org .springframework .ai .embedding .observation .DefaultEmbeddingModelObservationConvention ;
34+ import org .springframework .ai .embedding .observation .EmbeddingModelObservationContext ;
35+ import org .springframework .ai .embedding .observation .EmbeddingModelObservationConvention ;
36+ import org .springframework .ai .embedding .observation .EmbeddingModelObservationDocumentation ;
3137import org .springframework .ai .model .ModelOptionsUtils ;
38+ import org .springframework .ai .observation .conventions .AiProvider ;
3239import org .springframework .ai .retry .RetryUtils ;
3340import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingConnectionDetails ;
3441import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingUsage ;
3946import org .springframework .util .Assert ;
4047import org .springframework .util .StringUtils ;
4148
42- import java .io .IOException ;
43- import java .util .ArrayList ;
44- import java .util .List ;
45- import java .util .Map ;
46- import java .util .stream .Collectors ;
47- import java .util .stream .Stream ;
49+ import com .google .cloud .aiplatform .v1 .EndpointName ;
50+ import com .google .cloud .aiplatform .v1 .PredictRequest ;
51+ import com .google .cloud .aiplatform .v1 .PredictResponse ;
52+ import com .google .cloud .aiplatform .v1 .PredictionServiceClient ;
53+ import com .google .protobuf .Value ;
54+
55+ import io .micrometer .observation .ObservationRegistry ;
4856
4957/**
5058 * A class representing a Vertex AI Text Embedding Model.
5563 */
5664public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
5765
66+ private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention ();
67+
5868 public final VertexAiTextEmbeddingOptions defaultOptions ;
5969
6070 private final VertexAiEmbeddingConnectionDetails connectionDetails ;
6171
6272 private final RetryTemplate retryTemplate ;
6373
74+ /**
75+ * Observation registry used for instrumentation.
76+ */
77+ private final ObservationRegistry observationRegistry ;
78+
79+ /**
80+ * Conventions to use for generating observations.
81+ */
82+ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION ;
83+
6484 public VertexAiTextEmbeddingModel (VertexAiEmbeddingConnectionDetails connectionDetails ,
6585 VertexAiTextEmbeddingOptions defaultEmbeddingOptions ) {
6686 this (connectionDetails , defaultEmbeddingOptions , RetryUtils .DEFAULT_RETRY_TEMPLATE );
6787 }
6888
6989 public VertexAiTextEmbeddingModel (VertexAiEmbeddingConnectionDetails connectionDetails ,
7090 VertexAiTextEmbeddingOptions defaultEmbeddingOptions , RetryTemplate retryTemplate ) {
91+ this (connectionDetails , defaultEmbeddingOptions , retryTemplate , ObservationRegistry .NOOP );
92+ }
93+
94+ public VertexAiTextEmbeddingModel (VertexAiEmbeddingConnectionDetails connectionDetails ,
95+ VertexAiTextEmbeddingOptions defaultEmbeddingOptions , RetryTemplate retryTemplate ,
96+ ObservationRegistry observationRegistry ) {
7197 Assert .notNull (defaultEmbeddingOptions , "VertexAiTextEmbeddingOptions must not be null" );
7298 Assert .notNull (retryTemplate , "retryTemplate must not be null" );
99+ Assert .notNull (observationRegistry , "observationRegistry must not be null" );
73100 this .defaultOptions = defaultEmbeddingOptions .initializeDefaults ();
74101 this .connectionDetails = connectionDetails ;
75102 this .retryTemplate = retryTemplate ;
103+ this .observationRegistry = observationRegistry ;
76104 }
77105
78106 @ Override
@@ -83,42 +111,64 @@ public float[] embed(Document document) {
83111
84112 @ Override
85113 public EmbeddingResponse call (EmbeddingRequest request ) {
86- return retryTemplate .execute (context -> {
87- VertexAiTextEmbeddingOptions finalOptions = this .defaultOptions ;
88114
89- if (request .getOptions () != null && request .getOptions () != EmbeddingOptions .EMPTY ) {
90- var defaultOptionsCopy = VertexAiTextEmbeddingOptions .builder ().from (this .defaultOptions ).build ();
91- finalOptions = ModelOptionsUtils .merge (request .getOptions (), defaultOptionsCopy ,
92- VertexAiTextEmbeddingOptions .class );
93- }
115+ final VertexAiTextEmbeddingOptions finalOptions = mergedOptions (request );
94116
95- PredictionServiceClient client = createPredictionServiceClient ();
117+ var observationContext = EmbeddingModelObservationContext .builder ()
118+ .embeddingRequest (request )
119+ .provider (AiProvider .VERTEX_AI .value ())
120+ .requestOptions (finalOptions )
121+ .build ();
96122
97- EndpointName endpointName = this .connectionDetails .getEndpointName (finalOptions .getModel ());
123+ return EmbeddingModelObservationDocumentation .EMBEDDING_MODEL_OPERATION
124+ .observation (this .observationConvention , DEFAULT_OBSERVATION_CONVENTION , () -> observationContext ,
125+ this .observationRegistry )
126+ .observe (() -> {
127+ PredictionServiceClient client = createPredictionServiceClient ();
98128
99- PredictRequest .Builder predictRequestBuilder = getPredictRequestBuilder (request , endpointName ,
100- finalOptions );
129+ EndpointName endpointName = this .connectionDetails .getEndpointName (finalOptions .getModel ());
101130
102- PredictResponse embeddingResponse = getPredictResponse (client , predictRequestBuilder );
131+ PredictRequest .Builder predictRequestBuilder = getPredictRequestBuilder (request , endpointName ,
132+ finalOptions );
103133
104- int index = 0 ;
105- int totalTokenCount = 0 ;
106- List <Embedding > embeddingList = new ArrayList <>();
107- for (Value prediction : embeddingResponse .getPredictionsList ()) {
108- Value embeddings = prediction .getStructValue ().getFieldsOrThrow ("embeddings" );
109- Value statistics = embeddings .getStructValue ().getFieldsOrThrow ("statistics" );
110- Value tokenCount = statistics .getStructValue ().getFieldsOrThrow ("token_count" );
111- totalTokenCount = totalTokenCount + (int ) tokenCount .getNumberValue ();
134+ PredictResponse embeddingResponse = retryTemplate
135+ .execute (context -> getPredictResponse (client , predictRequestBuilder ));
112136
113- Value values = embeddings .getStructValue ().getFieldsOrThrow ("values" );
137+ int index = 0 ;
138+ int totalTokenCount = 0 ;
139+ List <Embedding > embeddingList = new ArrayList <>();
140+ for (Value prediction : embeddingResponse .getPredictionsList ()) {
141+ Value embeddings = prediction .getStructValue ().getFieldsOrThrow ("embeddings" );
142+ Value statistics = embeddings .getStructValue ().getFieldsOrThrow ("statistics" );
143+ Value tokenCount = statistics .getStructValue ().getFieldsOrThrow ("token_count" );
144+ totalTokenCount = totalTokenCount + (int ) tokenCount .getNumberValue ();
114145
115- float [] vectorValues = VertexAiEmbeddingUtils . toVector ( values );
146+ Value values = embeddings . getStructValue (). getFieldsOrThrow ( " values" );
116147
117- embeddingList .add (new Embedding (vectorValues , index ++));
118- }
119- return new EmbeddingResponse (embeddingList ,
120- generateResponseMetadata (finalOptions .getModel (), totalTokenCount ));
121- });
148+ float [] vectorValues = VertexAiEmbeddingUtils .toVector (values );
149+
150+ embeddingList .add (new Embedding (vectorValues , index ++));
151+ }
152+ EmbeddingResponse response = new EmbeddingResponse (embeddingList ,
153+ generateResponseMetadata (finalOptions .getModel (), totalTokenCount ));
154+
155+ observationContext .setResponse (response );
156+
157+ return response ;
158+ });
159+ }
160+
161+ private VertexAiTextEmbeddingOptions mergedOptions (EmbeddingRequest request ) {
162+
163+ VertexAiTextEmbeddingOptions mergedOptions = this .defaultOptions ;
164+
165+ if (request .getOptions () != null && request .getOptions () != EmbeddingOptions .EMPTY ) {
166+ var defaultOptionsCopy = VertexAiTextEmbeddingOptions .builder ().from (this .defaultOptions ).build ();
167+ mergedOptions = ModelOptionsUtils .merge (request .getOptions (), defaultOptionsCopy ,
168+ VertexAiTextEmbeddingOptions .class );
169+ }
170+
171+ return mergedOptions ;
122172 }
123173
124174 protected PredictRequest .Builder getPredictRequestBuilder (EmbeddingRequest request , EndpointName endpointName ,
@@ -183,4 +233,13 @@ public int dimensions() {
183233 .collect (Collectors .toMap (VertexAiTextEmbeddingModelName ::getName ,
184234 VertexAiTextEmbeddingModelName ::getDimensions ));
185235
236+ /**
237+ * Use the provided convention for reporting observation data
238+ * @param observationConvention The provided convention
239+ */
240+ public void setObservationConvention (EmbeddingModelObservationConvention observationConvention ) {
241+ Assert .notNull (observationConvention , "observationConvention cannot be null" );
242+ this .observationConvention = observationConvention ;
243+ }
244+
186245}
0 commit comments