Skip to content

Commit b6165c6

Browse files
committed
Add observability support to TransformersEmbeddingModel
- Integrate ObservationRegistry and EmbeddingModelObservationConvention
- Update TransformersEmbeddingModel to use observations
- Add TransformersEmbeddingModelObservationTests
- Update TransformersEmbeddingModelAutoConfiguration for observation support
- Add ONNX to AiProvider enum
1 parent 1606383 commit b6165c6

File tree

5 files changed

+230
-62
lines changed

5 files changed

+230
-62
lines changed

models/spring-ai-transformers/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@
8585
<scope>test</scope>
8686
</dependency>
8787

88+
<dependency>
89+
<groupId>io.micrometer</groupId>
90+
<artifactId>micrometer-observation-test</artifactId>
91+
<scope>test</scope>
92+
</dependency>
93+
8894
</dependencies>
8995

9096
</project>

models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java

Lines changed: 105 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -23,34 +23,40 @@
2323
import java.util.concurrent.atomic.AtomicInteger;
2424
import java.util.stream.Collectors;
2525

26-
import ai.djl.huggingface.tokenizers.Encoding;
27-
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
28-
import ai.djl.modality.nlp.preprocess.Tokenizer;
29-
import ai.djl.ndarray.NDArray;
30-
import ai.djl.ndarray.NDManager;
31-
import ai.djl.ndarray.types.DataType;
32-
import ai.djl.ndarray.types.Shape;
33-
import ai.onnxruntime.OnnxTensor;
34-
import ai.onnxruntime.OnnxValue;
35-
import ai.onnxruntime.OrtEnvironment;
36-
import ai.onnxruntime.OrtException;
37-
import ai.onnxruntime.OrtSession;
3826
import org.apache.commons.logging.Log;
3927
import org.apache.commons.logging.LogFactory;
40-
4128
import org.springframework.ai.document.Document;
4229
import org.springframework.ai.document.MetadataMode;
4330
import org.springframework.ai.embedding.AbstractEmbeddingModel;
4431
import org.springframework.ai.embedding.Embedding;
45-
import org.springframework.ai.embedding.EmbeddingOptions;
32+
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
4633
import org.springframework.ai.embedding.EmbeddingRequest;
4734
import org.springframework.ai.embedding.EmbeddingResponse;
35+
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
36+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
37+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
38+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
39+
import org.springframework.ai.observation.conventions.AiProvider;
4840
import org.springframework.beans.factory.InitializingBean;
4941
import org.springframework.core.io.DefaultResourceLoader;
5042
import org.springframework.core.io.Resource;
5143
import org.springframework.util.Assert;
5244
import org.springframework.util.StringUtils;
5345

46+
import ai.djl.huggingface.tokenizers.Encoding;
47+
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
48+
import ai.djl.modality.nlp.preprocess.Tokenizer;
49+
import ai.djl.ndarray.NDArray;
50+
import ai.djl.ndarray.NDManager;
51+
import ai.djl.ndarray.types.DataType;
52+
import ai.djl.ndarray.types.Shape;
53+
import ai.onnxruntime.OnnxTensor;
54+
import ai.onnxruntime.OnnxValue;
55+
import ai.onnxruntime.OrtEnvironment;
56+
import ai.onnxruntime.OrtException;
57+
import ai.onnxruntime.OrtSession;
58+
import io.micrometer.observation.ObservationRegistry;
59+
5460
/**
5561
* https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html
5662
*
@@ -60,6 +66,8 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement
6066

6167
private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class);
6268

69+
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
70+
6371
// ONNX tokenizer for the all-MiniLM-L6-v2 model
6472
public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
6573

@@ -126,13 +134,29 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement
126134

127135
private Set<String> onnxModelInputs;
128136

137+
/**
138+
* Observation registry used for instrumentation.
139+
*/
140+
private final ObservationRegistry observationRegistry;
141+
142+
/**
143+
* Conventions to use for generating observations.
144+
*/
145+
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
146+
129147
public TransformersEmbeddingModel() {
130148
this(MetadataMode.NONE);
131149
}
132150

133151
public TransformersEmbeddingModel(MetadataMode metadataMode) {
152+
this(metadataMode, ObservationRegistry.NOOP);
153+
}
154+
155+
public TransformersEmbeddingModel(MetadataMode metadataMode, ObservationRegistry observationRegistry) {
134156
Assert.notNull(metadataMode, "Metadata mode should not be null");
157+
Assert.notNull(observationRegistry, "Observation registry should not be null");
135158
this.metadataMode = metadataMode;
159+
this.observationRegistry = observationRegistry;
136160
}
137161

138162
public void setTokenizerOptions(Map<String, String> tokenizerOptions) {
@@ -231,7 +255,7 @@ public EmbeddingResponse embedForResponse(List<String> texts) {
231255

232256
@Override
233257
public List<float[]> embed(List<String> texts) {
234-
return this.call(new EmbeddingRequest(texts, EmbeddingOptions.EMPTY))
258+
return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
235259
.getResults()
236260
.stream()
237261
.map(e -> e.getOutput())
@@ -241,63 +265,79 @@ public List<float[]> embed(List<String> texts) {
241265
@Override
242266
public EmbeddingResponse call(EmbeddingRequest request) {
243267

244-
List<float[]> resultEmbeddings = new ArrayList<>();
268+
var observationContext = EmbeddingModelObservationContext.builder()
269+
.embeddingRequest(request)
270+
.provider(AiProvider.ONNX.value())
271+
.requestOptions(request.getOptions())
272+
.build();
245273

246-
try {
274+
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
275+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
276+
this.observationRegistry)
277+
.observe(() -> {
278+
List<float[]> resultEmbeddings = new ArrayList<>();
247279

248-
Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions());
280+
try {
249281

250-
long[][] input_ids0 = new long[encodings.length][];
251-
long[][] attention_mask0 = new long[encodings.length][];
252-
long[][] token_type_ids0 = new long[encodings.length][];
282+
Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions());
253283

254-
for (int i = 0; i < encodings.length; i++) {
255-
input_ids0[i] = encodings[i].getIds();
256-
attention_mask0[i] = encodings[i].getAttentionMask();
257-
token_type_ids0[i] = encodings[i].getTypeIds();
258-
}
284+
long[][] input_ids0 = new long[encodings.length][];
285+
long[][] attention_mask0 = new long[encodings.length][];
286+
long[][] token_type_ids0 = new long[encodings.length][];
259287

260-
OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
261-
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
262-
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
288+
for (int i = 0; i < encodings.length; i++) {
289+
input_ids0[i] = encodings[i].getIds();
290+
attention_mask0[i] = encodings[i].getAttentionMask();
291+
token_type_ids0[i] = encodings[i].getTypeIds();
292+
}
293+
294+
OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
295+
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
296+
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
263297

264-
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
265-
"token_type_ids", tokenTypeIds);
298+
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
299+
"token_type_ids", tokenTypeIds);
266300

267-
modelInputs = removeUnknownModelInputs(modelInputs);
301+
modelInputs = removeUnknownModelInputs(modelInputs);
268302

269-
// The Run result object is AutoCloseable to prevent references from leaking
270-
// out. Once the Result object is
271-
// closed, all it’s child OnnxValues are closed too.
272-
try (OrtSession.Result results = this.session.run(modelInputs)) {
303+
// The Run result object is AutoCloseable to prevent references from
304+
// leaking
305+
// out. Once the Result object is
306+
// closed, all its child OnnxValues are closed too.
307+
try (OrtSession.Result results = this.session.run(modelInputs)) {
273308

274-
// OnnxValue lastHiddenState = results.get(0);
275-
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
309+
// OnnxValue lastHiddenState = results.get(0);
310+
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
276311

277-
// 0 - batch_size (1..x)
278-
// 1 - sequence_length (128)
279-
// 2 - embedding dimensions (384)
280-
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
312+
// 0 - batch_size (1..x)
313+
// 1 - sequence_length (128)
314+
// 2 - embedding dimensions (384)
315+
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
281316

282-
try (NDManager manager = NDManager.newBaseManager()) {
283-
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
284-
NDArray ndAttentionMask = manager.create(attention_mask0);
317+
try (NDManager manager = NDManager.newBaseManager()) {
318+
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
319+
NDArray ndAttentionMask = manager.create(attention_mask0);
285320

286-
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
321+
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
287322

288-
for (int i = 0; i < embedding.size(0); i++) {
289-
resultEmbeddings.add(embedding.get(i).toFloatArray());
323+
for (int i = 0; i < embedding.size(0); i++) {
324+
resultEmbeddings.add(embedding.get(i).toFloatArray());
325+
}
326+
}
290327
}
291328
}
292-
}
293-
}
294-
catch (OrtException ex) {
295-
throw new RuntimeException(ex);
296-
}
329+
catch (OrtException ex) {
330+
throw new RuntimeException(ex);
331+
}
297332

298-
var indexCounter = new AtomicInteger(0);
299-
return new EmbeddingResponse(
300-
resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList());
333+
var indexCounter = new AtomicInteger(0);
334+
335+
EmbeddingResponse embeddingResponse = new EmbeddingResponse(
336+
resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.incrementAndGet())).toList());
337+
observationContext.setResponse(embeddingResponse);
338+
339+
return embeddingResponse;
340+
});
301341
}
302342

303343
private Map<String, OnnxTensor> removeUnknownModelInputs(Map<String, OnnxTensor> modelInputs) {
@@ -347,4 +387,13 @@ private static Resource toResource(String uri) {
347387
return new DefaultResourceLoader().getResource(uri);
348388
}
349389

390+
/**
391+
* Use the provided convention for reporting observation data
392+
* @param observationConvention The provided convention
393+
*/
394+
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
395+
Assert.notNull(observationConvention, "observationConvention cannot be null");
396+
this.observationConvention = observationConvention;
397+
}
398+
350399
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.transformers;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.List;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.springframework.ai.document.MetadataMode;
24+
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
25+
import org.springframework.ai.embedding.EmbeddingRequest;
26+
import org.springframework.ai.embedding.EmbeddingResponse;
27+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
28+
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
29+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
30+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
31+
import org.springframework.ai.observation.conventions.AiOperationType;
32+
import org.springframework.ai.observation.conventions.AiProvider;
33+
import org.springframework.beans.factory.annotation.Autowired;
34+
import org.springframework.boot.SpringBootConfiguration;
35+
import org.springframework.boot.test.context.SpringBootTest;
36+
import org.springframework.context.annotation.Bean;
37+
38+
import io.micrometer.observation.tck.TestObservationRegistry;
39+
import io.micrometer.observation.tck.TestObservationRegistryAssert;
40+
41+
/**
42+
* Integration tests for observation instrumentation in {@link TransformersEmbeddingModel}.
43+
*
44+
* @author Christian Tzolov
45+
*/
46+
@SpringBootTest(classes = TransformersEmbeddingModelObservationTests.Config.class)
47+
public class TransformersEmbeddingModelObservationTests {
48+
49+
@Autowired
50+
TestObservationRegistry observationRegistry;
51+
52+
@Autowired
53+
TransformersEmbeddingModel embeddingModel;
54+
55+
@Test
56+
void observationForEmbeddingOperation() {
57+
58+
var options = EmbeddingOptionsBuilder.builder().withModel("bert-base-uncased").build();
59+
60+
EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);
61+
62+
EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
63+
assertThat(embeddingResponse.getResults()).isNotEmpty();
64+
65+
EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
66+
assertThat(responseMetadata).isNotNull();
67+
68+
TestObservationRegistryAssert.assertThat(observationRegistry)
69+
.doesNotHaveAnyRemainingCurrentObservation()
70+
.hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
71+
.that()
72+
.hasContextualNameEqualTo("embedding " + "bert-base-uncased")
73+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
74+
AiOperationType.EMBEDDING.value())
75+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.ONNX.value())
76+
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "bert-base-uncased")
77+
// .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
78+
// responseMetadata.getModel())
79+
// .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(),
80+
// "1536")
81+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
82+
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
83+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
84+
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
85+
.hasBeenStarted()
86+
.hasBeenStopped();
87+
}
88+
89+
@SpringBootConfiguration
90+
static class Config {
91+
92+
@Bean
93+
public TestObservationRegistry observationRegistry() {
94+
return TestObservationRegistry.create();
95+
}
96+
97+
@Bean
98+
public TransformersEmbeddingModel openAiEmbeddingModel(TestObservationRegistry observationRegistry) {
99+
return new TransformersEmbeddingModel(MetadataMode.NONE, observationRegistry);
100+
}
101+
102+
}
103+
104+
}

spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ public enum AiProvider {
3939
MINIMAX("minimax"),
4040
MOONSHOT("moonshot"),
4141
SPRING_AI("spring_ai"),
42-
VERTEX_AI("vertex_ai");
42+
VERTEX_AI("vertex_ai"),
43+
ONNX("onnx");
4344

4445
private final String value;
4546

0 commit comments

Comments
 (0)