Skip to content

Commit 1e062ff

Browse files
committed
Improve resource management in TransformersEmbeddingModel
- Prevent potential resource leaks
- Wrap OnnxTensor creation in try-with-resources block
- Explicitly close inputIds, attentionMask, and tokenTypeIds tensors

Resolves #1427
1 parent 1019343 commit 1e062ff

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java

Lines changed: 28 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -291,39 +291,44 @@ public EmbeddingResponse call(EmbeddingRequest request) {
291291
token_type_ids0[i] = encodings[i].getTypeIds();
292292
}
293293

294-
OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
295-
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
296-
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
294+
try (OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
295+
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
296+
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);) {
297297

298-
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
299-
"token_type_ids", tokenTypeIds);
298+
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask",
299+
attentionMask, "token_type_ids", tokenTypeIds);
300300

301-
modelInputs = removeUnknownModelInputs(modelInputs);
301+
modelInputs = removeUnknownModelInputs(modelInputs);
302302

303-
// The Run result object is AutoCloseable to prevent references from
304-
// leaking
305-
// out. Once the Result object is
306-
// closed, all its child OnnxValues are closed too.
307-
try (OrtSession.Result results = this.session.run(modelInputs)) {
303+
// The Run result object is AutoCloseable to prevent references
304+
// from leaking out. Once the Result object is
305+
// closed, all its child OnnxValues are closed too.
306+
try (OrtSession.Result results = this.session.run(modelInputs)) {
308307

309-
// OnnxValue lastHiddenState = results.get(0);
310-
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
308+
// OnnxValue lastHiddenState = results.get(0);
309+
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
311310

312-
// 0 - batch_size (1..x)
313-
// 1 - sequence_length (128)
314-
// 2 - embedding dimensions (384)
315-
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
311+
// 0 - batch_size (1..x)
312+
// 1 - sequence_length (128)
313+
// 2 - embedding dimensions (384)
314+
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
316315

317-
try (NDManager manager = NDManager.newBaseManager()) {
318-
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
319-
NDArray ndAttentionMask = manager.create(attention_mask0);
316+
try (NDManager manager = NDManager.newBaseManager()) {
317+
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
318+
NDArray ndAttentionMask = manager.create(attention_mask0);
320319

321-
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
320+
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
322321

323-
for (int i = 0; i < embedding.size(0); i++) {
324-
resultEmbeddings.add(embedding.get(i).toFloatArray());
322+
for (int i = 0; i < embedding.size(0); i++) {
323+
resultEmbeddings.add(embedding.get(i).toFloatArray());
324+
}
325325
}
326326
}
327+
finally {
328+
inputIds.close();
329+
attentionMask.close();
330+
tokenTypeIds.close();
331+
}
327332
}
328333
}
329334
catch (OrtException ex) {

0 commit comments

Comments
 (0)