Skip to content

Commit caed1f0

Browse files
committed
lint(): fix lint errors
1 parent 8746e96 commit caed1f0

File tree

5 files changed

+44
-45
lines changed

5 files changed

+44
-45
lines changed

signwriting_evaluation/metrics/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ def score_all(self, hypotheses: Sequence[str], references: Sequence[str], progre
5959
def score_self(self, hypotheses: Sequence[str], progress_bar=True) -> list[list[float]]:
6060
if not self.SYMMETRIC:
6161
return self.score_all(hypotheses, hypotheses, progress_bar)
62-
62+
6363
# For symmetric metrics, only compute upper triangle to avoid redundant calculations
6464
n = len(hypotheses)
6565
scores = np.eye(n, dtype=np.float16) # Initialize with diagonal 1
66-
66+
6767
total = n * (n - 1) // 2 # Exclude diagonal
6868
iterator = tqdm([(i, j) for i in range(n) for j in range(i + 1, n)],
6969
total=total, disable=not progress_bar or total == 1)
70-
70+
7171
for i, j in iterator:
7272
score = self.score(hypotheses[i], hypotheses[j])
7373
scores[i][j] = scores[j][i] = score

signwriting_evaluation/metrics/bleu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from signwriting.tokenizer import SignWritingTokenizer
44

5-
from signwriting_evaluation.metrics.base import SignWritingMetric
5+
from signwriting_evaluation.metrics.base import SignWritingMetric, validate_corpus_score_input
66

77

88
class SignWritingBLEU(SignWritingMetric):
@@ -23,7 +23,7 @@ def score(self, hypothesis: str, reference: str) -> float:
2323
return self.bleu.sentence_score(hypothesis, [reference]).score / 100
2424

2525
def corpus_score(self, hypotheses: list[str], references: list[list[str]]) -> float:
26-
self.validate_corpus_score_input(hypotheses, references)
26+
validate_corpus_score_input(hypotheses, references)
2727
hypotheses = [self.tokenize(h) for h in hypotheses]
2828
references = [[self.tokenize(r) for r in reference] for reference in references]
2929
return self.bleu.corpus_score(hypotheses, references).score / 100

signwriting_evaluation/metrics/chrf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sacrebleu.metrics import CHRF
22

3-
from signwriting_evaluation.metrics.base import SignWritingMetric
3+
from signwriting_evaluation.metrics.base import SignWritingMetric, validate_corpus_score_input
44

55

66
class SignWritingCHRF(SignWritingMetric):
@@ -15,5 +15,5 @@ def score(self, hypothesis: str, reference: str) -> float:
1515
return self.chrf.sentence_score(hypothesis, [reference]).score / 100
1616

1717
def corpus_score(self, hypotheses: list[str], references: list[list[str]]) -> float:
18-
self.validate_corpus_score_input(hypotheses, references)
18+
validate_corpus_score_input(hypotheses, references)
1919
return self.chrf.corpus_score(hypotheses, references).score / 100

signwriting_evaluation/metrics/similarity.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_shape_class_index(shape: int) -> Optional[int]:
3535

3636

3737
@cache
38-
def text_to_signs(text: str) -> tuple[str]:
38+
def text_to_signs(text: str) -> tuple[str, ...]:
3939
text_as_fsw = swu2fsw(text) # converts swu symbols to fsw, while keeping the fsw symbols if present
4040
return tuple(normalize_signwriting(text_as_fsw).split(" "))
4141

@@ -49,70 +49,71 @@ def get_symbol_attributes(symbol: str) -> SymbolAttributes:
4949
return SymbolAttributes(shape, facing, angle, parallel)
5050

5151

52+
@cache
5253
def fast_positional_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> float:
5354
# Unbelievably, this is faster than using numpy or scipy for simple Euclidean distance
5455
# It reduces the overhead of converting to numpy arrays when calculating distances
5556
dx = pos1[0] - pos2[0]
5657
dy = pos1[1] - pos2[1]
5758
return math.sqrt(dx * dx + dy * dy)
5859

60+
61+
ERROR_WEIGHT = {
62+
"shape": 5, # same weight as switching parallelization
63+
"facing": 5 / 3, # more important than angle, not as much as shape and orientation
64+
"angle": 5 / 24, # lowest importance out of the criteria
65+
"parallel": 5, # parallelization is 3 columns compare to 1 for the facing direction
66+
"positional": 1 / 10, # may be big values
67+
"normalized_factor": 1 / 2.5, # fitting shape of function
68+
"exp_factor": 1.5, # exponential distribution
69+
"class_penalty": 100, # big penalty for each class type passed
70+
}
71+
72+
73+
@cache
74+
def fast_symbol_distance(attributes1: SymbolAttributes, attributes2: SymbolAttributes) -> float:
75+
d_shape = (attributes1.shape - attributes2.shape) * ERROR_WEIGHT["shape"]
76+
d_facing = (attributes1.facing - attributes2.facing) * ERROR_WEIGHT["facing"]
77+
d_angle = (attributes1.angle - attributes2.angle) * ERROR_WEIGHT["angle"]
78+
d_parallel = (attributes1.parallel != attributes2.parallel) * ERROR_WEIGHT["parallel"]
79+
return math.sqrt(d_shape * d_shape + \
80+
d_facing * d_facing + \
81+
d_angle * d_angle + \
82+
d_parallel * d_parallel)
83+
84+
5985
fsw_to_sign = cache(fsw_to_sign)
6086

87+
6188
class SignWritingSimilarityMetric(SignWritingMetric):
6289
SYMMETRIC = True
6390

6491
def __init__(self):
6592
super().__init__("SymbolsDistances")
66-
self.weight = {
67-
"shape": 5, # same weight as switching parallelization
68-
"facing": 5 / 3, # more important than angle, not as much as shape and orientation
69-
"angle": 5 / 24, # lowest importance out of the criteria
70-
"parallel": 5, # parallelization is 3 columns compare to 1 for the facing direction
71-
"positional": 1 / 10, # may be big values
72-
"normalized_factor": 1 / 2.5, # fitting shape of function
73-
"exp_factor": 1.5, # exponential distribution
74-
"class_penalty": 100, # big penalty for each class type passed
75-
}
76-
7793
self.max_distance = self.calculate_distance({"symbol": "S10000", "position": (250, 250)},
7894
{"symbol": "S38b07", "position": (750, 750)})
7995

80-
def weight_vector(self, attributes: SymbolAttributes) -> Tuple[float, ...]:
81-
weighted_values = self.symbol_weight_vector * attributes
82-
return weighted_values
83-
84-
@cache
85-
def symbol_distance(self, attributes1: SymbolAttributes, attributes2: SymbolAttributes) -> float:
86-
d_shape = (attributes1.shape - attributes2.shape) * self.weight["shape"]
87-
d_facing = (attributes1.facing - attributes2.facing) * self.weight["facing"]
88-
d_angle = (attributes1.angle - attributes2.angle) * self.weight["angle"]
89-
d_parallel = (attributes1.parallel != attributes2.parallel) * self.weight["parallel"]
90-
return math.sqrt(d_shape * d_shape + \
91-
d_facing * d_facing + \
92-
d_angle * d_angle + \
93-
d_parallel * d_parallel)
94-
9596
def calculate_distance(self, hyp: SignSymbol, ref: SignSymbol) -> float:
9697
hyp_attributes = get_symbol_attributes(hyp['symbol'])
9798
ref_attributes = get_symbol_attributes(ref['symbol'])
9899

99-
symbols_distance = self.symbol_distance(hyp_attributes, ref_attributes)
100+
symbols_distance = fast_symbol_distance(hyp_attributes, ref_attributes)
100101

101102
position_euclidean = fast_positional_distance(hyp["position"], ref["position"])
102-
position_distance = self.weight["positional"] * position_euclidean
103+
position_distance = ERROR_WEIGHT["positional"] * position_euclidean
103104

104105
hyp_class = get_shape_class_index(hyp_attributes.shape)
105106
ref_class = get_shape_class_index(ref_attributes.shape)
106107

107108
if hyp_class is None or ref_class is None:
108109
return self.max_distance
109110

110-
class_penalty = abs(hyp_class - ref_class) * self.weight["class_penalty"]
111+
class_penalty = abs(hyp_class - ref_class) * ERROR_WEIGHT["class_penalty"]
111112

112113
return symbols_distance + position_distance + class_penalty
113114

114115
def normalized_distance(self, unnormalized: float) -> float:
115-
return pow(unnormalized / self.max_distance, self.weight["normalized_factor"])
116+
return pow(unnormalized / self.max_distance, ERROR_WEIGHT["normalized_factor"])
116117

117118
def symbols_score(self, hyp: SignSymbol, ref: SignSymbol) -> float:
118119
distance = self.calculate_distance(hyp, ref)
@@ -135,12 +136,10 @@ def error_rate(self, hyp: Sign, ref: Sign) -> float:
135136
cost_matrix = cost_matrix.reshape(len(hyp["symbols"]), -1)
136137
# Find the lowest cost matching
137138
row_ind, col_ind = linear_sum_assignment(cost_matrix)
138-
pairs = list(zip(row_ind, col_ind))
139-
# Print the matching and total cost
140-
values = [cost_matrix[row, col] for row, col in pairs]
141-
mean_cost = sum(values) / len(values)
139+
mean_cost = float(cost_matrix[row_ind, col_ind].mean())
140+
142141
length_error = self.length_acc(hyp, ref)
143-
length_weight = pow(length_error, self.weight["exp_factor"])
142+
length_weight = pow(length_error, ERROR_WEIGHT["exp_factor"])
144143
return length_weight + mean_cost * (1 - length_weight)
145144

146145
def score_single_sign(self, hypothesis: str, reference: str) -> float:

signwriting_evaluation/metrics/test_similarity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def test_score(self):
1717
def test_score_is_symemtric(self):
1818
reference = "M519x534S37900497x466S3770b497x485S15a51491x501S22f03481x513"
1919
hypothesis = "M530x538S37602508x462S15a11493x494S20e00488x510S22f03469x517"
20-
score1 = self.metric.score(hypothesis, reference)
21-
score2 = self.metric.score(reference, hypothesis)
20+
score1 = self.metric.score(hypothesis=hypothesis, reference=reference)
21+
score2 = self.metric.score(hypothesis=reference, reference=hypothesis)
2222
self.assertAlmostEqual(score1, score2, msg="The metric is not symmetric")
2323

2424
def test_score_jumbled_sign(self):

0 commit comments

Comments
 (0)