Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,39 +291,44 @@ public EmbeddingResponse call(EmbeddingRequest request) {
token_type_ids0[i] = encodings[i].getTypeIds();
}

OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
try (OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);) {

Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
"token_type_ids", tokenTypeIds);
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask",
attentionMask, "token_type_ids", tokenTypeIds);

modelInputs = removeUnknownModelInputs(modelInputs);
modelInputs = removeUnknownModelInputs(modelInputs);

// The Run result object is AutoCloseable to prevent references from
// leaking
// out. Once the Result object is
// closed, all its child OnnxValues are closed too.
try (OrtSession.Result results = this.session.run(modelInputs)) {
// The Run result object is AutoCloseable to prevent references
// from leaking out. Once the Result object is
// closed, all its child OnnxValues are closed too.
try (OrtSession.Result results = this.session.run(modelInputs)) {

// OnnxValue lastHiddenState = results.get(0);
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();
// OnnxValue lastHiddenState = results.get(0);
OnnxValue lastHiddenState = results.get(this.modelOutputName).get();

// 0 - batch_size (1..x)
// 1 - sequence_length (128)
// 2 - embedding dimensions (384)
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
// 0 - batch_size (1..x)
// 1 - sequence_length (128)
// 2 - embedding dimensions (384)
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();

try (NDManager manager = NDManager.newBaseManager()) {
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
NDArray ndAttentionMask = manager.create(attention_mask0);
try (NDManager manager = NDManager.newBaseManager()) {
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
NDArray ndAttentionMask = manager.create(attention_mask0);

NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);

for (int i = 0; i < embedding.size(0); i++) {
resultEmbeddings.add(embedding.get(i).toFloatArray());
for (int i = 0; i < embedding.size(0); i++) {
resultEmbeddings.add(embedding.get(i).toFloatArray());
}
}
}
finally {
inputIds.close();
attentionMask.close();
tokenTypeIds.close();
}
}
}
catch (OrtException ex) {
Expand Down