diff --git a/src/finbert_qa.py b/src/finbert_qa.py index c6c628f..648986c 100644 --- a/src/finbert_qa.py +++ b/src/finbert_qa.py @@ -842,11 +842,11 @@ def predict(self, model, q_text, cands): pred = pred.detach().cpu().numpy() # Append relevant scores to list (where label = 1) scores.append(pred[:,1][0]) - # Get the indices of the sorted similarity scores - sorted_index = np.argsort(scores)[::-1] - # Get the list of docid from the sorted indices - ranked_ans = list(cands_id[sorted_index]) - sorted_scores = list(np.around(sorted(scores, reverse=True),decimals=3)) + # Get the indices of the sorted similarity scores + sorted_index = np.argsort(scores)[::-1] + # Get the list of docid from the sorted indices + ranked_ans = list(cands_id[sorted_index]) + sorted_scores = list(np.around(sorted(scores, reverse=True),decimals=3)) return ranked_ans, sorted_scores