From b6165c6a30391565a3babc05add78a97c9a7775c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 8 Oct 2024 08:38:03 +0200 Subject: [PATCH] Add observability support to TransformersEmbeddingModel - Integrate ObservationRegistry and EmbeddingModelObservationConvention - Update TransformersEmbeddingModel to use observations - Add TransformersEmbeddingModelObservationTests - Update TransformersEmbeddingModelAutoConfiguration for observation support - Add ONNX to AiProvider enum --- models/spring-ai-transformers/pom.xml | 6 + .../TransformersEmbeddingModel.java | 161 ++++++++++++------ ...formersEmbeddingModelObservationTests.java | 104 +++++++++++ .../observation/conventions/AiProvider.java | 3 +- ...ormersEmbeddingModelAutoConfiguration.java | 18 +- 5 files changed, 230 insertions(+), 62 deletions(-) create mode 100644 models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java diff --git a/models/spring-ai-transformers/pom.xml b/models/spring-ai-transformers/pom.xml index 65f075aface..086e0064f15 100644 --- a/models/spring-ai-transformers/pom.xml +++ b/models/spring-ai-transformers/pom.xml @@ -85,6 +85,12 @@ test + + io.micrometer + micrometer-observation-test + test + + diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java index e666fbdf225..d86460526e2 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java @@ -23,34 +23,40 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; -import ai.djl.huggingface.tokenizers.Encoding; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.djl.modality.nlp.preprocess.Tokenizer; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.DataType; -import ai.djl.ndarray.types.Shape; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; - import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; import org.springframework.ai.embedding.Embedding; -import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.beans.factory.InitializingBean; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.nlp.preprocess.Tokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import io.micrometer.observation.ObservationRegistry; + /** * https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html * @@ -60,6 +66,8 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class); + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + // ONNX tokenizer for the all-MiniLM-L6-v2 generative public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; @@ -126,13 +134,29 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement private Set onnxModelInputs; + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + public TransformersEmbeddingModel() { this(MetadataMode.NONE); } public TransformersEmbeddingModel(MetadataMode metadataMode) { + this(metadataMode, ObservationRegistry.NOOP); + } + + public TransformersEmbeddingModel(MetadataMode metadataMode, ObservationRegistry observationRegistry) { Assert.notNull(metadataMode, "Metadata mode should not be null"); + Assert.notNull(observationRegistry, "Observation registry should not be null"); this.metadataMode = metadataMode; + this.observationRegistry = observationRegistry; } public void setTokenizerOptions(Map tokenizerOptions) { @@ -231,7 +255,7 @@ public EmbeddingResponse embedForResponse(List texts) { @Override public List embed(List texts) { - return this.call(new EmbeddingRequest(texts, EmbeddingOptions.EMPTY)) + return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build())) .getResults() .stream() .map(e -> e.getOutput()) @@ -241,63 +265,79 @@ public List embed(List texts) { @Override public EmbeddingResponse call(EmbeddingRequest request) { - List resultEmbeddings = new ArrayList<>(); + var observationContext = EmbeddingModelObservationContext.builder() + .embeddingRequest(request) + .provider(AiProvider.ONNX.value()) + .requestOptions(request.getOptions()) + .build(); - try { + return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + List resultEmbeddings = new ArrayList<>(); - Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions()); + try { - long[][] input_ids0 = new long[encodings.length][]; - long[][] attention_mask0 = new long[encodings.length][]; - long[][] token_type_ids0 = new long[encodings.length][]; + Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions()); - for (int i = 0; i < encodings.length; i++) { - input_ids0[i] = encodings[i].getIds(); - attention_mask0[i] = encodings[i].getAttentionMask(); - token_type_ids0[i] = encodings[i].getTypeIds(); - } + long[][] input_ids0 = new long[encodings.length][]; + long[][] attention_mask0 = new long[encodings.length][]; + long[][] token_type_ids0 = new long[encodings.length][]; - OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0); - OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0); - OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0); + for (int i = 0; i < encodings.length; i++) { + input_ids0[i] = encodings[i].getIds(); + attention_mask0[i] = encodings[i].getAttentionMask(); + token_type_ids0[i] = encodings[i].getTypeIds(); + } + + OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0); + OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0); + OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0); - Map modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask, - "token_type_ids", tokenTypeIds); + Map modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask, + "token_type_ids", tokenTypeIds); - modelInputs = removeUnknownModelInputs(modelInputs); + modelInputs = removeUnknownModelInputs(modelInputs); - // The Run result object is AutoCloseable to prevent references from leaking - // out. Once the Result object is - // closed, all it’s child OnnxValues are closed too. - try (OrtSession.Result results = this.session.run(modelInputs)) { + // The Run result object is AutoCloseable to prevent references from + // leaking + // out. Once the Result object is + // closed, all it’s child OnnxValues are closed too. + try (OrtSession.Result results = this.session.run(modelInputs)) { - // OnnxValue lastHiddenState = results.get(0); - OnnxValue lastHiddenState = results.get(this.modelOutputName).get(); + // OnnxValue lastHiddenState = results.get(0); + OnnxValue lastHiddenState = results.get(this.modelOutputName).get(); - // 0 - batch_size (1..x) - // 1 - sequence_length (128) - // 2 - embedding dimensions (384) - float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue(); + // 0 - batch_size (1..x) + // 1 - sequence_length (128) + // 2 - embedding dimensions (384) + float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue(); - try (NDManager manager = NDManager.newBaseManager()) { - NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager); - NDArray ndAttentionMask = manager.create(attention_mask0); + try (NDManager manager = NDManager.newBaseManager()) { + NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager); + NDArray ndAttentionMask = manager.create(attention_mask0); - NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask); + NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask); - for (int i = 0; i < embedding.size(0); i++) { - resultEmbeddings.add(embedding.get(i).toFloatArray()); + for (int i = 0; i < embedding.size(0); i++) { + resultEmbeddings.add(embedding.get(i).toFloatArray()); + } + } } } - } - } - catch (OrtException ex) { - throw new RuntimeException(ex); - } + catch (OrtException ex) { + throw new RuntimeException(ex); + } - var indexCounter = new AtomicInteger(0); - return new EmbeddingResponse( - resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList()); + var indexCounter = new AtomicInteger(0); + + EmbeddingResponse embeddingResponse = new EmbeddingResponse( + resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList()); + observationContext.setResponse(embeddingResponse); + + return embeddingResponse; + }); } private Map removeUnknownModelInputs(Map modelInputs) { @@ -347,4 +387,13 @@ private static Resource toResource(String uri) { return new DefaultResourceLoader().getResource(uri); } + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + } \ No newline at end of file diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java new file mode 100644 index 00000000000..ec3c9c5ad50 --- /dev/null +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.transformers; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +/** + * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}. + * + * @author Christian Tzolov + */ +@SpringBootTest(classes = TransformersEmbeddingModelObservationTests.Config.class) +public class TransformersEmbeddingModelObservationTests { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + TransformersEmbeddingModel embeddingModel; + + @Test + void observationForEmbeddingOperation() { + + var options = EmbeddingOptionsBuilder.builder().withModel("bert-base-uncased").build(); + + EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); + + EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).isNotEmpty(); + + EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("embedding " + "bert-base-uncased") + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.EMBEDDING.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.ONNX.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "bert-base-uncased") + // .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), + // responseMetadata.getModel()) + // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), + // "1536") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public TransformersEmbeddingModel openAiEmbeddingModel(TestObservationRegistry observationRegistry) { + return new TransformersEmbeddingModel(MetadataMode.NONE, observationRegistry); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 3ff73b98b48..aff11d6af03 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -39,7 +39,8 @@ public enum AiProvider { MINIMAX("minimax"), MOONSHOT("moonshot"), SPRING_AI("spring_ai"), - VERTEX_AI("vertex_ai"); + VERTEX_AI("vertex_ai"), + ONNX("onnx"); private final String value; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java index 2964ffb297e..583d568631c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java @@ -15,10 +15,9 @@ */ package org.springframework.ai.autoconfigure.transformers; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.onnxruntime.OrtSession; - +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -26,6 +25,10 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.onnxruntime.OrtSession; +import io.micrometer.observation.ObservationRegistry; + /** * @author Christian Tzolov */ @@ -38,9 +41,12 @@ public class TransformersEmbeddingModelAutoConfiguration { @ConditionalOnMissingBean @ConditionalOnProperty(prefix = TransformersEmbeddingModelProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public TransformersEmbeddingModel embeddingModel(TransformersEmbeddingModelProperties properties) { + public TransformersEmbeddingModel embeddingModel(TransformersEmbeddingModelProperties properties, + ObjectProvider observationRegistry, + ObjectProvider observationConvention) { - TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel(properties.getMetadataMode()); + TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel(properties.getMetadataMode(), + observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); embeddingModel.setDisableCaching(!properties.getCache().isEnabled()); embeddingModel.setResourceCacheDirectory(properties.getCache().getDirectory()); @@ -54,6 +60,8 @@ public TransformersEmbeddingModel embeddingModel(TransformersEmbeddingModelPrope embeddingModel.setModelOutputName(properties.getOnnx().getModelOutputName()); + observationConvention.ifAvailable(embeddingModel::setObservationConvention); + return embeddingModel; }