@@ -35,7 +35,7 @@ def get_shape_class_index(shape: int) -> Optional[int]:
3535
3636
3737@cache
38- def text_to_signs (text : str ) -> tuple [str ]:
38+ def text_to_signs (text : str ) -> tuple [str , ... ]:
3939 text_as_fsw = swu2fsw (text ) # converts swu symbols to fsw, while keeping the fsw symbols if present
4040 return tuple (normalize_signwriting (text_as_fsw ).split (" " ))
4141
@@ -49,70 +49,71 @@ def get_symbol_attributes(symbol: str) -> SymbolAttributes:
4949 return SymbolAttributes (shape , facing , angle , parallel )
5050
5151
52+ @cache
5253def fast_positional_distance (pos1 : Tuple [int , int ], pos2 : Tuple [int , int ]) -> float :
5354 # Unbelievably, this is faster than using numpy or scipy for simple Euclidean distance
5455 # It reduces the overhead of converting to numpy arrays when calculating distances
5556 dx = pos1 [0 ] - pos2 [0 ]
5657 dy = pos1 [1 ] - pos2 [1 ]
5758 return math .sqrt (dx * dx + dy * dy )
5859
60+
61+ ERROR_WEIGHT = {
62+ "shape" : 5 , # same weight as switching parallelization
63+ "facing" : 5 / 3 , # more important than angle, not as much as shape and orientation
64+ "angle" : 5 / 24 , # lowest importance out of the criteria
65+ "parallel" : 5 , # parallelization is 3 columns compare to 1 for the facing direction
66+ "positional" : 1 / 10 , # may be big values
67+ "normalized_factor" : 1 / 2.5 , # fitting shape of function
68+ "exp_factor" : 1.5 , # exponential distribution
69+ "class_penalty" : 100 , # big penalty for each class type passed
70+ }
71+
72+
73+ @cache
74+ def fast_symbol_distance (attributes1 : SymbolAttributes , attributes2 : SymbolAttributes ) -> float :
75+ d_shape = (attributes1 .shape - attributes2 .shape ) * ERROR_WEIGHT ["shape" ]
76+ d_facing = (attributes1 .facing - attributes2 .facing ) * ERROR_WEIGHT ["facing" ]
77+ d_angle = (attributes1 .angle - attributes2 .angle ) * ERROR_WEIGHT ["angle" ]
78+ d_parallel = (attributes1 .parallel != attributes2 .parallel ) * ERROR_WEIGHT ["parallel" ]
79+ return math .sqrt (d_shape * d_shape + \
80+ d_facing * d_facing + \
81+ d_angle * d_angle + \
82+ d_parallel * d_parallel )
83+
84+
5985fsw_to_sign = cache (fsw_to_sign )
6086
87+
6188class SignWritingSimilarityMetric (SignWritingMetric ):
6289 SYMMETRIC = True
6390
6491 def __init__ (self ):
6592 super ().__init__ ("SymbolsDistances" )
66- self .weight = {
67- "shape" : 5 , # same weight as switching parallelization
68- "facing" : 5 / 3 , # more important than angle, not as much as shape and orientation
69- "angle" : 5 / 24 , # lowest importance out of the criteria
70- "parallel" : 5 , # parallelization is 3 columns compare to 1 for the facing direction
71- "positional" : 1 / 10 , # may be big values
72- "normalized_factor" : 1 / 2.5 , # fitting shape of function
73- "exp_factor" : 1.5 , # exponential distribution
74- "class_penalty" : 100 , # big penalty for each class type passed
75- }
76-
7793 self .max_distance = self .calculate_distance ({"symbol" : "S10000" , "position" : (250 , 250 )},
7894 {"symbol" : "S38b07" , "position" : (750 , 750 )})
7995
80- def weight_vector (self , attributes : SymbolAttributes ) -> Tuple [float , ...]:
81- weighted_values = self .symbol_weight_vector * attributes
82- return weighted_values
83-
84- @cache
85- def symbol_distance (self , attributes1 : SymbolAttributes , attributes2 : SymbolAttributes ) -> float :
86- d_shape = (attributes1 .shape - attributes2 .shape ) * self .weight ["shape" ]
87- d_facing = (attributes1 .facing - attributes2 .facing ) * self .weight ["facing" ]
88- d_angle = (attributes1 .angle - attributes2 .angle ) * self .weight ["angle" ]
89- d_parallel = (attributes1 .parallel != attributes2 .parallel ) * self .weight ["parallel" ]
90- return math .sqrt (d_shape * d_shape + \
91- d_facing * d_facing + \
92- d_angle * d_angle + \
93- d_parallel * d_parallel )
94-
9596 def calculate_distance (self , hyp : SignSymbol , ref : SignSymbol ) -> float :
9697 hyp_attributes = get_symbol_attributes (hyp ['symbol' ])
9798 ref_attributes = get_symbol_attributes (ref ['symbol' ])
9899
99- symbols_distance = self . symbol_distance (hyp_attributes , ref_attributes )
100+ symbols_distance = fast_symbol_distance (hyp_attributes , ref_attributes )
100101
101102 position_euclidean = fast_positional_distance (hyp ["position" ], ref ["position" ])
102- position_distance = self . weight ["positional" ] * position_euclidean
103+ position_distance = ERROR_WEIGHT ["positional" ] * position_euclidean
103104
104105 hyp_class = get_shape_class_index (hyp_attributes .shape )
105106 ref_class = get_shape_class_index (ref_attributes .shape )
106107
107108 if hyp_class is None or ref_class is None :
108109 return self .max_distance
109110
110- class_penalty = abs (hyp_class - ref_class ) * self . weight ["class_penalty" ]
111+ class_penalty = abs (hyp_class - ref_class ) * ERROR_WEIGHT ["class_penalty" ]
111112
112113 return symbols_distance + position_distance + class_penalty
113114
114115 def normalized_distance (self , unnormalized : float ) -> float :
115- return pow (unnormalized / self .max_distance , self . weight ["normalized_factor" ])
116+ return pow (unnormalized / self .max_distance , ERROR_WEIGHT ["normalized_factor" ])
116117
117118 def symbols_score (self , hyp : SignSymbol , ref : SignSymbol ) -> float :
118119 distance = self .calculate_distance (hyp , ref )
@@ -135,12 +136,10 @@ def error_rate(self, hyp: Sign, ref: Sign) -> float:
135136 cost_matrix = cost_matrix .reshape (len (hyp ["symbols" ]), - 1 )
136137 # Find the lowest cost matching
137138 row_ind , col_ind = linear_sum_assignment (cost_matrix )
138- pairs = list (zip (row_ind , col_ind ))
139- # Print the matching and total cost
140- values = [cost_matrix [row , col ] for row , col in pairs ]
141- mean_cost = sum (values ) / len (values )
139+ mean_cost = float (cost_matrix [row_ind , col_ind ].mean ())
140+
142141 length_error = self .length_acc (hyp , ref )
143- length_weight = pow (length_error , self . weight ["exp_factor" ])
142+ length_weight = pow (length_error , ERROR_WEIGHT ["exp_factor" ])
144143 return length_weight + mean_cost * (1 - length_weight )
145144
146145 def score_single_sign (self , hypothesis : str , reference : str ) -> float :
0 commit comments