Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions models/spring-ai-transformers/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,40 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.nlp.preprocess.Tokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.nlp.preprocess.Tokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import io.micrometer.observation.ObservationRegistry;

/**
* https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html
*
Expand All @@ -60,6 +66,8 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement

private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class);

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

// ONNX tokenizer for the all-MiniLM-L6-v2 model
public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";

Expand Down Expand Up @@ -126,13 +134,29 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement

private Set<String> onnxModelInputs;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

/**
* Creates a model using {@link MetadataMode#NONE}. Delegates to
* {@link #TransformersEmbeddingModel(MetadataMode)}.
*/
public TransformersEmbeddingModel() {
this(MetadataMode.NONE);
}

/**
* Creates a model with the given metadata mode and {@link ObservationRegistry#NOOP}
* as the observation registry. Delegates to
* {@link #TransformersEmbeddingModel(MetadataMode, ObservationRegistry)}.
* @param metadataMode the metadata mode to use; must not be {@code null}
*/
public TransformersEmbeddingModel(MetadataMode metadataMode) {
this(metadataMode, ObservationRegistry.NOOP);
}

/**
* Creates a model with the given metadata mode and observation registry.
* @param metadataMode the metadata mode to use; must not be {@code null}
* @param observationRegistry the registry used for instrumentation; must not be
* {@code null}
* @throws IllegalArgumentException if either argument is {@code null}
*/
public TransformersEmbeddingModel(MetadataMode metadataMode, ObservationRegistry observationRegistry) {
// Fail fast: both collaborators are required.
Assert.notNull(metadataMode, "Metadata mode should not be null");
Assert.notNull(observationRegistry, "Observation registry should not be null");
this.metadataMode = metadataMode;
this.observationRegistry = observationRegistry;
}

public void setTokenizerOptions(Map<String, String> tokenizerOptions) {
Expand Down Expand Up @@ -231,7 +255,7 @@ public EmbeddingResponse embedForResponse(List<String> texts) {

@Override
public List<float[]> embed(List<String> texts) {
return this.call(new EmbeddingRequest(texts, EmbeddingOptions.EMPTY))
return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
.getResults()
.stream()
.map(e -> e.getOutput())
Expand All @@ -241,63 +265,79 @@ public List<float[]> embed(List<String> texts) {
/**
* Computes embeddings for the instructions of the given request inside an
* {@code EMBEDDING_MODEL_OPERATION} observation.
* <p>
* The instructions are batch-encoded by the tokenizer, the resulting ids, attention
* masks and type ids are fed to the ONNX session, and the output named by
* {@code modelOutputName} is mean-pooled with the attention mask to obtain one vector
* per instruction. The response is also stored on the observation context.
* @param request the embedding request holding the instructions and options
* @return the response containing one {@link Embedding} per pooled row
* @throws RuntimeException wrapping any {@link OrtException} raised by ONNX Runtime
*/
@Override
public EmbeddingResponse call(EmbeddingRequest request) {

List<float[]> resultEmbeddings = new ArrayList<>();
var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(AiProvider.ONNX.value())
.requestOptions(request.getOptions())
.build();

try {
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
List<float[]> resultEmbeddings = new ArrayList<>();

Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions());
try {

long[][] input_ids0 = new long[encodings.length][];
long[][] attention_mask0 = new long[encodings.length][];
long[][] token_type_ids0 = new long[encodings.length][];
Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions());

for (int i = 0; i < encodings.length; i++) {
input_ids0[i] = encodings[i].getIds();
attention_mask0[i] = encodings[i].getAttentionMask();
token_type_ids0[i] = encodings[i].getTypeIds();
}
long[][] input_ids0 = new long[encodings.length][];
long[][] attention_mask0 = new long[encodings.length][];
long[][] token_type_ids0 = new long[encodings.length][];

OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
for (int i = 0; i < encodings.length; i++) {
input_ids0[i] = encodings[i].getIds();
attention_mask0[i] = encodings[i].getAttentionMask();
token_type_ids0[i] = encodings[i].getTypeIds();
}

// NOTE(review): these three input tensors are not closed anywhere in this
// method; OnnxTensor is AutoCloseable, so try-with-resources looks
// appropriate here — confirm against the ONNX Runtime API.
OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);

Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
"token_type_ids", tokenTypeIds);
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
"token_type_ids", tokenTypeIds);

modelInputs = removeUnknownModelInputs(modelInputs);
modelInputs = removeUnknownModelInputs(modelInputs);

// The Run result object is AutoCloseable to prevent references from leaking
// out. Once the Result object is
// closed, all its child OnnxValues are closed too.
try (OrtSession.Result results = this.session.run(modelInputs)) {
// The Run result object is AutoCloseable to prevent references from
// leaking
// out. Once the Result object is
// closed, all its child OnnxValues are closed too.
try (OrtSession.Result results = this.session.run(modelInputs)) {

// OnnxValue lastHiddenState = results.get(0);
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
// OnnxValue lastHiddenState = results.get(0);
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();

// 0 - batch_size (1..x)
// 1 - sequence_length (128)
// 2 - embedding dimensions (384)
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
// 0 - batch_size (1..x)
// 1 - sequence_length (128)
// 2 - embedding dimensions (384)
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();

try (NDManager manager = NDManager.newBaseManager()) {
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
NDArray ndAttentionMask = manager.create(attention_mask0);
try (NDManager manager = NDManager.newBaseManager()) {
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
NDArray ndAttentionMask = manager.create(attention_mask0);

NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);

for (int i = 0; i < embedding.size(0); i++) {
resultEmbeddings.add(embedding.get(i).toFloatArray());
for (int i = 0; i < embedding.size(0); i++) {
resultEmbeddings.add(embedding.get(i).toFloatArray());
}
}
}
}
}
}
catch (OrtException ex) {
throw new RuntimeException(ex);
}
catch (OrtException ex) {
throw new RuntimeException(ex);
}

var indexCounter = new AtomicInteger(0);
return new EmbeddingResponse(
resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList());
// NOTE(review): incrementAndGet() on a counter starting at 0 gives the first
// embedding index 1, not 0 — confirm this is intended.
var indexCounter = new AtomicInteger(0);

EmbeddingResponse embeddingResponse = new EmbeddingResponse(
resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList());
// Expose the response to the observation convention before the observation stops.
observationContext.setResponse(embeddingResponse);

return embeddingResponse;
});
}

private Map<String, OnnxTensor> removeUnknownModelInputs(Map<String, OnnxTensor> modelInputs) {
Expand Down Expand Up @@ -347,4 +387,13 @@ private static Resource toResource(String uri) {
return new DefaultResourceLoader().getResource(uri);
}

/**
* Sets the convention used to generate observation data for embedding calls,
* replacing the default {@link DefaultEmbeddingModelObservationConvention}.
* @param observationConvention the convention to use; must not be {@code null}
* @throws IllegalArgumentException if {@code observationConvention} is {@code null}
*/
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.transformers;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

import org.junit.jupiter.api.Test;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;

/**
 * Integration tests for observation instrumentation in {@link TransformersEmbeddingModel}.
 *
 * @author Christian Tzolov
 */
@SpringBootTest(classes = TransformersEmbeddingModelObservationTests.Config.class)
public class TransformersEmbeddingModelObservationTests {

	/** Model name placed in the request options and expected back in the observation. */
	private static final String REQUEST_MODEL = "bert-base-uncased";

	@Autowired
	TestObservationRegistry observationRegistry;

	@Autowired
	TransformersEmbeddingModel embeddingModel;

	/**
	 * Calls the model once and verifies that a single embedding observation was
	 * started and stopped with the expected name, provider, request model and usage
	 * key values.
	 */
	@Test
	void observationForEmbeddingOperation() {

		var options = EmbeddingOptionsBuilder.builder().withModel(REQUEST_MODEL).build();

		EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);

		EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest);
		assertThat(embeddingResponse.getResults()).isNotEmpty();

		EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
		assertThat(responseMetadata).isNotNull();

		TestObservationRegistryAssert.assertThat(this.observationRegistry)
			.doesNotHaveAnyRemainingCurrentObservation()
			.hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
			.that()
			.hasContextualNameEqualTo("embedding " + REQUEST_MODEL)
			.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
					AiOperationType.EMBEDDING.value())
			.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.ONNX.value())
			.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), REQUEST_MODEL)
			// NOTE(review): the response-model and embedding-dimension assertions below
			// were left disabled by the author — confirm whether this model reports them.
			// .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
			// responseMetadata.getModel())
			// .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(),
			// "1536")
			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
					String.valueOf(responseMetadata.getUsage().getPromptTokens()))
			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
					String.valueOf(responseMetadata.getUsage().getTotalTokens()))
			.hasBeenStarted()
			.hasBeenStopped();
	}

	/**
	 * Minimal test configuration: a {@link TestObservationRegistry} and a
	 * {@link TransformersEmbeddingModel} wired to it.
	 */
	@SpringBootConfiguration
	static class Config {

		@Bean
		public TestObservationRegistry observationRegistry() {
			return TestObservationRegistry.create();
		}

		// Named after the type it produces; both test fields are injected by type,
		// so the bean name is not referenced anywhere in this class.
		@Bean
		public TransformersEmbeddingModel transformersEmbeddingModel(TestObservationRegistry observationRegistry) {
			return new TransformersEmbeddingModel(MetadataMode.NONE, observationRegistry);
		}

	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public enum AiProvider {
MINIMAX("minimax"),
MOONSHOT("moonshot"),
SPRING_AI("spring_ai"),
VERTEX_AI("vertex_ai");
VERTEX_AI("vertex_ai"),
ONNX("onnx");

private final String value;

Expand Down
Loading