2323import java .util .concurrent .atomic .AtomicInteger ;
2424import java .util .stream .Collectors ;
2525
26- import ai .djl .huggingface .tokenizers .Encoding ;
27- import ai .djl .huggingface .tokenizers .HuggingFaceTokenizer ;
28- import ai .djl .modality .nlp .preprocess .Tokenizer ;
29- import ai .djl .ndarray .NDArray ;
30- import ai .djl .ndarray .NDManager ;
31- import ai .djl .ndarray .types .DataType ;
32- import ai .djl .ndarray .types .Shape ;
33- import ai .onnxruntime .OnnxTensor ;
34- import ai .onnxruntime .OnnxValue ;
35- import ai .onnxruntime .OrtEnvironment ;
36- import ai .onnxruntime .OrtException ;
37- import ai .onnxruntime .OrtSession ;
3826import org .apache .commons .logging .Log ;
3927import org .apache .commons .logging .LogFactory ;
40-
4128import org .springframework .ai .document .Document ;
4229import org .springframework .ai .document .MetadataMode ;
4330import org .springframework .ai .embedding .AbstractEmbeddingModel ;
4431import org .springframework .ai .embedding .Embedding ;
45- import org .springframework .ai .embedding .EmbeddingOptions ;
32+ import org .springframework .ai .embedding .EmbeddingOptionsBuilder ;
4633import org .springframework .ai .embedding .EmbeddingRequest ;
4734import org .springframework .ai .embedding .EmbeddingResponse ;
35+ import org .springframework .ai .embedding .observation .DefaultEmbeddingModelObservationConvention ;
36+ import org .springframework .ai .embedding .observation .EmbeddingModelObservationContext ;
37+ import org .springframework .ai .embedding .observation .EmbeddingModelObservationConvention ;
38+ import org .springframework .ai .embedding .observation .EmbeddingModelObservationDocumentation ;
39+ import org .springframework .ai .observation .conventions .AiProvider ;
4840import org .springframework .beans .factory .InitializingBean ;
4941import org .springframework .core .io .DefaultResourceLoader ;
5042import org .springframework .core .io .Resource ;
5143import org .springframework .util .Assert ;
5244import org .springframework .util .StringUtils ;
5345
46+ import ai .djl .huggingface .tokenizers .Encoding ;
47+ import ai .djl .huggingface .tokenizers .HuggingFaceTokenizer ;
48+ import ai .djl .modality .nlp .preprocess .Tokenizer ;
49+ import ai .djl .ndarray .NDArray ;
50+ import ai .djl .ndarray .NDManager ;
51+ import ai .djl .ndarray .types .DataType ;
52+ import ai .djl .ndarray .types .Shape ;
53+ import ai .onnxruntime .OnnxTensor ;
54+ import ai .onnxruntime .OnnxValue ;
55+ import ai .onnxruntime .OrtEnvironment ;
56+ import ai .onnxruntime .OrtException ;
57+ import ai .onnxruntime .OrtSession ;
58+ import io .micrometer .observation .ObservationRegistry ;
59+
5460/**
5561 * https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html
5662 *
@@ -60,6 +66,8 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement
6066
6167 private static final Log logger = LogFactory .getLog (TransformersEmbeddingModel .class );
6268
69+ private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention ();
70+
6371 // ONNX tokenizer for the all-MiniLM-L6-v2 generative
6472 public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json" ;
6573
@@ -126,13 +134,29 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement
126134
127135 private Set <String > onnxModelInputs ;
128136
137+ /**
138+ * Observation registry used for instrumentation.
139+ */
140+ private final ObservationRegistry observationRegistry ;
141+
142+ /**
143+ * Conventions to use for generating observations.
144+ */
145+ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION ;
146+
129147 public TransformersEmbeddingModel () {
130148 this (MetadataMode .NONE );
131149 }
132150
133151 public TransformersEmbeddingModel (MetadataMode metadataMode ) {
152+ this (metadataMode , ObservationRegistry .NOOP );
153+ }
154+
155+ public TransformersEmbeddingModel (MetadataMode metadataMode , ObservationRegistry observationRegistry ) {
134156 Assert .notNull (metadataMode , "Metadata mode should not be null" );
157+ Assert .notNull (observationRegistry , "Observation registry should not be null" );
135158 this .metadataMode = metadataMode ;
159+ this .observationRegistry = observationRegistry ;
136160 }
137161
138162 public void setTokenizerOptions (Map <String , String > tokenizerOptions ) {
@@ -231,7 +255,7 @@ public EmbeddingResponse embedForResponse(List<String> texts) {
231255
232256 @ Override
233257 public List <float []> embed (List <String > texts ) {
234- return this .call (new EmbeddingRequest (texts , EmbeddingOptions . EMPTY ))
258+ return this .call (new EmbeddingRequest (texts , EmbeddingOptionsBuilder . builder (). build () ))
235259 .getResults ()
236260 .stream ()
237261 .map (e -> e .getOutput ())
@@ -241,63 +265,79 @@ public List<float[]> embed(List<String> texts) {
241265 @ Override
242266 public EmbeddingResponse call (EmbeddingRequest request ) {
243267
244- List <float []> resultEmbeddings = new ArrayList <>();
268+ var observationContext = EmbeddingModelObservationContext .builder ()
269+ .embeddingRequest (request )
270+ .provider (AiProvider .ONNX .value ())
271+ .requestOptions (request .getOptions ())
272+ .build ();
245273
246- try {
274+ return EmbeddingModelObservationDocumentation .EMBEDDING_MODEL_OPERATION
275+ .observation (this .observationConvention , DEFAULT_OBSERVATION_CONVENTION , () -> observationContext ,
276+ this .observationRegistry )
277+ .observe (() -> {
278+ List <float []> resultEmbeddings = new ArrayList <>();
247279
248- Encoding [] encodings = this . tokenizer . batchEncode ( request . getInstructions ());
280+ try {
249281
250- long [][] input_ids0 = new long [encodings .length ][];
251- long [][] attention_mask0 = new long [encodings .length ][];
252- long [][] token_type_ids0 = new long [encodings .length ][];
282+ Encoding [] encodings = this .tokenizer .batchEncode (request .getInstructions ());
253283
254- for (int i = 0 ; i < encodings .length ; i ++) {
255- input_ids0 [i ] = encodings [i ].getIds ();
256- attention_mask0 [i ] = encodings [i ].getAttentionMask ();
257- token_type_ids0 [i ] = encodings [i ].getTypeIds ();
258- }
284+ long [][] input_ids0 = new long [encodings .length ][];
285+ long [][] attention_mask0 = new long [encodings .length ][];
286+ long [][] token_type_ids0 = new long [encodings .length ][];
259287
260- OnnxTensor inputIds = OnnxTensor .createTensor (this .environment , input_ids0 );
261- OnnxTensor attentionMask = OnnxTensor .createTensor (this .environment , attention_mask0 );
262- OnnxTensor tokenTypeIds = OnnxTensor .createTensor (this .environment , token_type_ids0 );
288+ for (int i = 0 ; i < encodings .length ; i ++) {
289+ input_ids0 [i ] = encodings [i ].getIds ();
290+ attention_mask0 [i ] = encodings [i ].getAttentionMask ();
291+ token_type_ids0 [i ] = encodings [i ].getTypeIds ();
292+ }
293+
294+ OnnxTensor inputIds = OnnxTensor .createTensor (this .environment , input_ids0 );
295+ OnnxTensor attentionMask = OnnxTensor .createTensor (this .environment , attention_mask0 );
296+ OnnxTensor tokenTypeIds = OnnxTensor .createTensor (this .environment , token_type_ids0 );
263297
264- Map <String , OnnxTensor > modelInputs = Map .of ("input_ids" , inputIds , "attention_mask" , attentionMask ,
265- "token_type_ids" , tokenTypeIds );
298+ Map <String , OnnxTensor > modelInputs = Map .of ("input_ids" , inputIds , "attention_mask" , attentionMask ,
299+ "token_type_ids" , tokenTypeIds );
266300
267- modelInputs = removeUnknownModelInputs (modelInputs );
301+ modelInputs = removeUnknownModelInputs (modelInputs );
268302
269- // The Run result object is AutoCloseable to prevent references from leaking
270- // out. Once the Result object is
271- // closed, all it’s child OnnxValues are closed too.
272- try (OrtSession .Result results = this .session .run (modelInputs )) {
303+ // The Run result object is AutoCloseable to prevent references from
304+ // leaking
305+ // out. Once the Result object is
306+ // closed, all it’s child OnnxValues are closed too.
307+ try (OrtSession .Result results = this .session .run (modelInputs )) {
273308
274- // OnnxValue lastHiddenState = results.get(0);
275- OnnxValue lastHiddenState = results .get (this .modelOutputName ).get ();
309+ // OnnxValue lastHiddenState = results.get(0);
310+ OnnxValue lastHiddenState = results .get (this .modelOutputName ).get ();
276311
277- // 0 - batch_size (1..x)
278- // 1 - sequence_length (128)
279- // 2 - embedding dimensions (384)
280- float [][][] tokenEmbeddings = (float [][][]) lastHiddenState .getValue ();
312+ // 0 - batch_size (1..x)
313+ // 1 - sequence_length (128)
314+ // 2 - embedding dimensions (384)
315+ float [][][] tokenEmbeddings = (float [][][]) lastHiddenState .getValue ();
281316
282- try (NDManager manager = NDManager .newBaseManager ()) {
283- NDArray ndTokenEmbeddings = create (tokenEmbeddings , manager );
284- NDArray ndAttentionMask = manager .create (attention_mask0 );
317+ try (NDManager manager = NDManager .newBaseManager ()) {
318+ NDArray ndTokenEmbeddings = create (tokenEmbeddings , manager );
319+ NDArray ndAttentionMask = manager .create (attention_mask0 );
285320
286- NDArray embedding = meanPooling (ndTokenEmbeddings , ndAttentionMask );
321+ NDArray embedding = meanPooling (ndTokenEmbeddings , ndAttentionMask );
287322
288- for (int i = 0 ; i < embedding .size (0 ); i ++) {
289- resultEmbeddings .add (embedding .get (i ).toFloatArray ());
323+ for (int i = 0 ; i < embedding .size (0 ); i ++) {
324+ resultEmbeddings .add (embedding .get (i ).toFloatArray ());
325+ }
326+ }
290327 }
291328 }
292- }
293- }
294- catch (OrtException ex ) {
295- throw new RuntimeException (ex );
296- }
329+ catch (OrtException ex ) {
330+ throw new RuntimeException (ex );
331+ }
297332
298- var indexCounter = new AtomicInteger (0 );
299- return new EmbeddingResponse (
300- resultEmbeddings .stream ().map (e -> new Embedding (e , indexCounter .incrementAndGet ())).toList ());
333+ var indexCounter = new AtomicInteger (0 );
334+
335+ EmbeddingResponse embeddingResponse = new EmbeddingResponse (
336+ resultEmbeddings .stream ().map (e -> new Embedding (e , indexCounter .incrementAndGet ())).toList ());
337+ observationContext .setResponse (embeddingResponse );
338+
339+ return embeddingResponse ;
340+ });
301341 }
302342
303343 private Map <String , OnnxTensor > removeUnknownModelInputs (Map <String , OnnxTensor > modelInputs ) {
@@ -347,4 +387,13 @@ private static Resource toResource(String uri) {
347387 return new DefaultResourceLoader ().getResource (uri );
348388 }
349389
390+ /**
391+ * Use the provided convention for reporting observation data
392+ * @param observationConvention The provided convention
393+ */
394+ public void setObservationConvention (EmbeddingModelObservationConvention observationConvention ) {
395+ Assert .notNull (observationConvention , "observationConvention cannot be null" );
396+ this .observationConvention = observationConvention ;
397+ }
398+
350399}
0 commit comments