From 073f4575e13ae500371230d1227a357743f0117d Mon Sep 17 00:00:00 2001 From: Gautam Dudeja Date: Mon, 27 Jun 2022 15:09:26 -0400 Subject: [PATCH] Update finbert_qa.py Modify predict function to sort after computing scores of all answers --- src/finbert_qa.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/finbert_qa.py b/src/finbert_qa.py index c6c628f..648986c 100644 --- a/src/finbert_qa.py +++ b/src/finbert_qa.py @@ -842,11 +842,11 @@ def predict(self, model, q_text, cands): pred = pred.detach().cpu().numpy() # Append relevant scores to list (where label = 1) scores.append(pred[:,1][0]) - # Get the indices of the sorted similarity scores - sorted_index = np.argsort(scores)[::-1] - # Get the list of docid from the sorted indices - ranked_ans = list(cands_id[sorted_index]) - sorted_scores = list(np.around(sorted(scores, reverse=True),decimals=3)) + # Get the indices of the sorted similarity scores + sorted_index = np.argsort(scores)[::-1] + # Get the list of docid from the sorted indices + ranked_ans = list(cands_id[sorted_index]) + sorted_scores = list(np.around(sorted(scores, reverse=True),decimals=3)) return ranked_ans, sorted_scores