From 1e062ff3cb8572b349be70f0e20904a52ae0d9c9 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 8 Oct 2024 11:28:29 +0200 Subject: [PATCH] Improve resource management in TransformersEmbeddingModel - Prevent potential resource leaks - Wrap OnnxTensor creation in try-with-resources block - Explicitly close inputIds, attentionMask, and tokenTypeIds tensors Resolves #1427 --- .../TransformersEmbeddingModel.java | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java index d86460526e2..05a0fb514b8 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java @@ -291,39 +291,44 @@ public EmbeddingResponse call(EmbeddingRequest request) { token_type_ids0[i] = encodings[i].getTypeIds(); } - OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0); - OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0); - OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0); + try (OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0); + OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0); + OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);) { - Map modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask, - "token_type_ids", tokenTypeIds); + Map modelInputs = Map.of("input_ids", inputIds, "attention_mask", + attentionMask, "token_type_ids", tokenTypeIds); - modelInputs = removeUnknownModelInputs(modelInputs); + modelInputs = 
removeUnknownModelInputs(modelInputs); - // The Run result object is AutoCloseable to prevent references from - // leaking - // out. Once the Result object is - // closed, all it’s child OnnxValues are closed too. - try (OrtSession.Result results = this.session.run(modelInputs)) { + // The Run result object is AutoCloseable to prevent references + // from leaking out. Once the Result object is + // closed, all its child OnnxValues are closed too. + try (OrtSession.Result results = this.session.run(modelInputs)) { - // OnnxValue lastHiddenState = results.get(0); - OnnxValue lastHiddenState = results.get(this.modelOutputName).get(); + // OnnxValue lastHiddenState = results.get(0); + OnnxValue lastHiddenState = results.get(this.modelOutputName).get(); - // 0 - batch_size (1..x) - // 1 - sequence_length (128) - // 2 - embedding dimensions (384) - float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue(); + // 0 - batch_size (1..x) + // 1 - sequence_length (128) + // 2 - embedding dimensions (384) + float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue(); - try (NDManager manager = NDManager.newBaseManager()) { - NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager); - NDArray ndAttentionMask = manager.create(attention_mask0); + try (NDManager manager = NDManager.newBaseManager()) { + NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager); + NDArray ndAttentionMask = manager.create(attention_mask0); - NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask); + NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask); - for (int i = 0; i < embedding.size(0); i++) { - resultEmbeddings.add(embedding.get(i).toFloatArray()); + for (int i = 0; i < embedding.size(0); i++) { + resultEmbeddings.add(embedding.get(i).toFloatArray()); + } } } + finally { + inputIds.close(); + attentionMask.close(); + tokenTypeIds.close(); + } } } catch (OrtException ex) {