Skip to content

Commit b22f03d

Browse files
committed
fix(similarity): maximize score instead of minimize
1 parent 4b6681a commit b22f03d

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

signwriting_evaluation/metrics/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ def corpus_score(self, hypotheses: list[str], references: list[list[str]]) -> fl
2929
transpose_references = list(zip(*references))
3030
return sum(self.score_max(h, r) for h, r in zip(hypotheses, transpose_references)) / len(hypotheses)
3131

32-
def score_all(self, hypotheses: list[str], references: list[str]) -> list[list[float]]:
32+
def score_all(self, hypotheses: list[str], references: list[str], progress_bar=True) -> list[list[float]]:
3333
# Default implementation: call the score function for each hypothesis-reference pair
34-
return [[self.score(h, r) for r in references] for h in tqdm(hypotheses, disable=len(hypotheses) == 1)]
34+
return [[self.score(h, r) for r in references]
35+
for h in tqdm(hypotheses, disable=not progress_bar or len(hypotheses) == 1)]
3536

3637
def __str__(self):
3738
return self.name

signwriting_evaluation/metrics/similarity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def score(self, hypothesis: str, reference: str) -> float:
117117
reference_signs += [""] * (max_length - len(reference_signs))
118118

119119
# Match each hypothesis sign with each reference sign
120-
cost_matrix = self.score_all(hypothesis_signs, reference_signs)
121-
row_ind, col_ind = linear_sum_assignment(cost_matrix)
120+
cost_matrix = self.score_all(hypothesis_signs, reference_signs, progress_bar=False)
121+
row_ind, col_ind = linear_sum_assignment(1 - np.array(cost_matrix))
122122
pairs = list(zip(row_ind, col_ind))
123123
values = [cost_matrix[row][col] for row, col in pairs]
124124
return sum(values) / len(values)

signwriting_evaluation/metrics/test_similarity.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ def test_multi_sign_score(self):
4444
self.assertIsInstance(score, float)
4545
self.assertAlmostEqual(score, 0.8326259781509948 / 2)
4646

47+
def test_multi_sign_score_is_order_invariant(self):
48+
sign_1 = "M530x538S17600508x462S15a11493x494S20e00488x510S22f03469x517"
49+
sign_2 = "M530x538S17600508x462S12a11493x494S20e00488x510S22f13469x517"
50+
hypothesis = f"{sign_1} {sign_2}"
51+
reference = f"{sign_2} {sign_1}"
52+
score = self.metric.score(hypothesis, reference)
53+
self.assertAlmostEqual(score, 1)
54+
4755
def test_bad_fsw_equals_0(self):
4856
bad_fsw = "M<s><s>M<s>p483"
4957
score = self.metric.corpus_score([bad_fsw], [[bad_fsw]])

0 commit comments

Comments
 (0)