diff --git a/models/spring-ai-transformers/pom.xml b/models/spring-ai-transformers/pom.xml
index 65f075aface..086e0064f15 100644
--- a/models/spring-ai-transformers/pom.xml
+++ b/models/spring-ai-transformers/pom.xml
@@ -85,6 +85,12 @@
test
+
+ io.micrometer
+ micrometer-observation-test
+ test
+
+
diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java
index e666fbdf225..d86460526e2 100644
--- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java
+++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java
@@ -23,34 +23,40 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
-import ai.djl.huggingface.tokenizers.Encoding;
-import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
-import ai.djl.modality.nlp.preprocess.Tokenizer;
-import ai.djl.ndarray.NDArray;
-import ai.djl.ndarray.NDManager;
-import ai.djl.ndarray.types.DataType;
-import ai.djl.ndarray.types.Shape;
-import ai.onnxruntime.OnnxTensor;
-import ai.onnxruntime.OnnxValue;
-import ai.onnxruntime.OrtEnvironment;
-import ai.onnxruntime.OrtException;
-import ai.onnxruntime.OrtSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
-import org.springframework.ai.embedding.EmbeddingOptions;
+import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
+import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.modality.nlp.preprocess.Tokenizer;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OnnxValue;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import io.micrometer.observation.ObservationRegistry;
+
/**
* https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html
*
@@ -60,6 +66,8 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement
private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class);
+ private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
+
// ONNX tokenizer for the all-MiniLM-L6-v2 generative
public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
@@ -126,13 +134,29 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement
private Set onnxModelInputs;
+ /**
+ * Observation registry used for instrumentation.
+ */
+ private final ObservationRegistry observationRegistry;
+
+ /**
+ * Conventions to use for generating observations.
+ */
+ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
+
public TransformersEmbeddingModel() {
this(MetadataMode.NONE);
}
public TransformersEmbeddingModel(MetadataMode metadataMode) {
+ this(metadataMode, ObservationRegistry.NOOP);
+ }
+
+ public TransformersEmbeddingModel(MetadataMode metadataMode, ObservationRegistry observationRegistry) {
Assert.notNull(metadataMode, "Metadata mode should not be null");
+ Assert.notNull(observationRegistry, "Observation registry should not be null");
this.metadataMode = metadataMode;
+ this.observationRegistry = observationRegistry;
}
public void setTokenizerOptions(Map tokenizerOptions) {
@@ -231,7 +255,7 @@ public EmbeddingResponse embedForResponse(List texts) {
@Override
public List embed(List texts) {
- return this.call(new EmbeddingRequest(texts, EmbeddingOptions.EMPTY))
+ return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
.getResults()
.stream()
.map(e -> e.getOutput())
@@ -241,63 +265,79 @@ public List embed(List texts) {
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
- List resultEmbeddings = new ArrayList<>();
+ var observationContext = EmbeddingModelObservationContext.builder()
+ .embeddingRequest(request)
+ .provider(AiProvider.ONNX.value())
+ .requestOptions(request.getOptions())
+ .build();
- try {
+ return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
+ .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry)
+ .observe(() -> {
+ List resultEmbeddings = new ArrayList<>();
- Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions());
+ try {
- long[][] input_ids0 = new long[encodings.length][];
- long[][] attention_mask0 = new long[encodings.length][];
- long[][] token_type_ids0 = new long[encodings.length][];
+ Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions());
- for (int i = 0; i < encodings.length; i++) {
- input_ids0[i] = encodings[i].getIds();
- attention_mask0[i] = encodings[i].getAttentionMask();
- token_type_ids0[i] = encodings[i].getTypeIds();
- }
+ long[][] input_ids0 = new long[encodings.length][];
+ long[][] attention_mask0 = new long[encodings.length][];
+ long[][] token_type_ids0 = new long[encodings.length][];
- OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
- OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
- OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
+ for (int i = 0; i < encodings.length; i++) {
+ input_ids0[i] = encodings[i].getIds();
+ attention_mask0[i] = encodings[i].getAttentionMask();
+ token_type_ids0[i] = encodings[i].getTypeIds();
+ }
+
+ OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
+ OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
+ OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
- Map modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
- "token_type_ids", tokenTypeIds);
+ Map modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
+ "token_type_ids", tokenTypeIds);
- modelInputs = removeUnknownModelInputs(modelInputs);
+ modelInputs = removeUnknownModelInputs(modelInputs);
- // The Run result object is AutoCloseable to prevent references from leaking
- // out. Once the Result object is
- // closed, all it’s child OnnxValues are closed too.
- try (OrtSession.Result results = this.session.run(modelInputs)) {
+ // The Run result object is AutoCloseable to prevent references from
+ // leaking
+ // out. Once the Result object is
+ // closed, all it’s child OnnxValues are closed too.
+ try (OrtSession.Result results = this.session.run(modelInputs)) {
- // OnnxValue lastHiddenState = results.get(0);
- OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
+ // OnnxValue lastHiddenState = results.get(0);
+ OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
- // 0 - batch_size (1..x)
- // 1 - sequence_length (128)
- // 2 - embedding dimensions (384)
- float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
+ // 0 - batch_size (1..x)
+ // 1 - sequence_length (128)
+ // 2 - embedding dimensions (384)
+ float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
- try (NDManager manager = NDManager.newBaseManager()) {
- NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
- NDArray ndAttentionMask = manager.create(attention_mask0);
+ try (NDManager manager = NDManager.newBaseManager()) {
+ NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
+ NDArray ndAttentionMask = manager.create(attention_mask0);
- NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
+ NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
- for (int i = 0; i < embedding.size(0); i++) {
- resultEmbeddings.add(embedding.get(i).toFloatArray());
+ for (int i = 0; i < embedding.size(0); i++) {
+ resultEmbeddings.add(embedding.get(i).toFloatArray());
+ }
+ }
}
}
- }
- }
- catch (OrtException ex) {
- throw new RuntimeException(ex);
- }
+ catch (OrtException ex) {
+ throw new RuntimeException(ex);
+ }
- var indexCounter = new AtomicInteger(0);
- return new EmbeddingResponse(
- resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList());
+ var indexCounter = new AtomicInteger(0);
+
+ EmbeddingResponse embeddingResponse = new EmbeddingResponse(
+ resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList());
+ observationContext.setResponse(embeddingResponse);
+
+ return embeddingResponse;
+ });
}
private Map removeUnknownModelInputs(Map modelInputs) {
@@ -347,4 +387,13 @@ private static Resource toResource(String uri) {
return new DefaultResourceLoader().getResource(uri);
}
+ /**
+ * Use the provided convention for reporting observation data
+ * @param observationConvention The provided convention
+ */
+ public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
+ Assert.notNull(observationConvention, "observationConvention cannot be null");
+ this.observationConvention = observationConvention;
+ }
+
}
\ No newline at end of file
diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java
new file mode 100644
index 00000000000..ec3c9c5ad50
--- /dev/null
+++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.transformers;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
+import org.springframework.ai.observation.conventions.AiOperationType;
+import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
+
+import io.micrometer.observation.tck.TestObservationRegistry;
+import io.micrometer.observation.tck.TestObservationRegistryAssert;
+
+/**
+ * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}.
+ *
+ * @author Christian Tzolov
+ */
+@SpringBootTest(classes = TransformersEmbeddingModelObservationTests.Config.class)
+public class TransformersEmbeddingModelObservationTests {
+
+ @Autowired
+ TestObservationRegistry observationRegistry;
+
+ @Autowired
+ TransformersEmbeddingModel embeddingModel;
+
+ @Test
+ void observationForEmbeddingOperation() {
+
+ var options = EmbeddingOptionsBuilder.builder().withModel("bert-base-uncased").build();
+
+ EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);
+
+ EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
+ assertThat(embeddingResponse.getResults()).isNotEmpty();
+
+ EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
+ assertThat(responseMetadata).isNotNull();
+
+ TestObservationRegistryAssert.assertThat(observationRegistry)
+ .doesNotHaveAnyRemainingCurrentObservation()
+ .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
+ .that()
+ .hasContextualNameEqualTo("embedding " + "bert-base-uncased")
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
+ AiOperationType.EMBEDDING.value())
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.ONNX.value())
+ .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "bert-base-uncased")
+ // .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
+ // responseMetadata.getModel())
+ // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(),
+ // "1536")
+ .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getPromptTokens()))
+ .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getTotalTokens()))
+ .hasBeenStarted()
+ .hasBeenStopped();
+ }
+
+ @SpringBootConfiguration
+ static class Config {
+
+ @Bean
+ public TestObservationRegistry observationRegistry() {
+ return TestObservationRegistry.create();
+ }
+
+ @Bean
+ public TransformersEmbeddingModel openAiEmbeddingModel(TestObservationRegistry observationRegistry) {
+ return new TransformersEmbeddingModel(MetadataMode.NONE, observationRegistry);
+ }
+
+ }
+
+}
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java
index 3ff73b98b48..aff11d6af03 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java
@@ -39,7 +39,8 @@ public enum AiProvider {
MINIMAX("minimax"),
MOONSHOT("moonshot"),
SPRING_AI("spring_ai"),
- VERTEX_AI("vertex_ai");
+ VERTEX_AI("vertex_ai"),
+ ONNX("onnx");
private final String value;
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java
index 2964ffb297e..583d568631c 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java
@@ -15,10 +15,9 @@
*/
package org.springframework.ai.autoconfigure.transformers;
-import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
-import ai.onnxruntime.OrtSession;
-
+import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
+import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -26,6 +25,10 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.onnxruntime.OrtSession;
+import io.micrometer.observation.ObservationRegistry;
+
/**
* @author Christian Tzolov
*/
@@ -38,9 +41,12 @@ public class TransformersEmbeddingModelAutoConfiguration {
@ConditionalOnMissingBean
@ConditionalOnProperty(prefix = TransformersEmbeddingModelProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
- public TransformersEmbeddingModel embeddingModel(TransformersEmbeddingModelProperties properties) {
+ public TransformersEmbeddingModel embeddingModel(TransformersEmbeddingModelProperties properties,
+ ObjectProvider observationRegistry,
+ ObjectProvider observationConvention) {
- TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel(properties.getMetadataMode());
+ TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel(properties.getMetadataMode(),
+ observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
embeddingModel.setDisableCaching(!properties.getCache().isEnabled());
embeddingModel.setResourceCacheDirectory(properties.getCache().getDirectory());
@@ -54,6 +60,8 @@ public TransformersEmbeddingModel embeddingModel(TransformersEmbeddingModelPrope
embeddingModel.setModelOutputName(properties.getOnnx().getModelOutputName());
+ observationConvention.ifAvailable(embeddingModel::setObservationConvention);
+
return embeddingModel;
}