Skip to content

Commit f9a9f39

Browse files
albertz and vieting authored
ScliteJob, precision_ndigit option (#442)
Co-authored-by: vieting <[email protected]>
1 parent 6bb893f commit f9a9f39

File tree

2 files changed

+95
-27
lines changed

2 files changed

+95
-27
lines changed

recognition/scoring.py

Lines changed: 69 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
import tempfile
1313
import collections
1414
import re
15-
from typing import List, Optional
15+
from typing import List, Optional, Dict, Tuple
1616

1717
from sisyphus import *
1818
from i6_core.lib.corpus import *
@@ -51,7 +51,7 @@ class ScliteJob(Job):
5151
- out_*: the job also outputs many variables, please look in the init code for a list
5252
"""
5353

54-
__sis_hash_exclude__ = {"sctk_binary_path": None}
54+
__sis_hash_exclude__ = {"sctk_binary_path": None, "precision_ndigit": 1}
5555

5656
def __init__(
5757
self,
@@ -61,6 +61,7 @@ def __init__(
6161
sort_files: bool = False,
6262
additional_args: Optional[List[str]] = None,
6363
sctk_binary_path: Optional[tk.Path] = None,
64+
precision_ndigit: Optional[int] = 1,
6465
):
6566
"""
6667
:param ref: reference stm text file
@@ -69,6 +70,12 @@ def __init__(
6970
:param sort_files: sort ctm and stm before scoring
7071
:param additional_args: additional command line arguments passed to the Sclite binary call
7172
:param sctk_binary_path: set an explicit binary path.
73+
:param precision_ndigit: number of digits after decimal point for the precision
74+
of the percentages in the output variables.
75+
If None, no rounding is done.
76+
In sclite, the precision was always one digit after the decimal point
77+
(https://github.com/usnistgov/SCTK/blob/f48376a203ab17f/src/sclite/sc_dtl.c#L343),
78+
thus we recalculate the percentages here.
7279
"""
7380
self.set_vis_name("Sclite - %s" % ("CER" if cer else "WER"))
7481

@@ -78,6 +85,7 @@ def __init__(
7885
self.sort_files = sort_files
7986
self.additional_args = additional_args
8087
self.sctk_binary_path = sctk_binary_path
88+
self.precision_ndigit = precision_ndigit
8189

8290
self.out_report_dir = self.output_path("reports", True)
8391

@@ -149,31 +157,66 @@ def run(self, output_to_report_dir=True):
149157

150158
if output_to_report_dir: # run as real job
151159
with open(f"{output_dir}/sclite.dtl", "rt", errors="ignore") as f:
160+
# Example:
161+
"""
162+
Percent Total Error = 5.3% (2709)
163+
...
164+
Percent Word Accuracy = 94.7%
165+
...
166+
Ref. words = (50948)
167+
"""
168+
169+
# key -> percentage, absolute
170+
output_variables: Dict[str, Tuple[Optional[tk.Variable], Optional[tk.Variable]]] = {
171+
"Percent Total Error": (self.out_wer, self.out_num_errors),
172+
"Percent Correct": (self.out_percent_correct, self.out_num_correct),
173+
"Percent Substitution": (self.out_percent_substitution, self.out_num_substitution),
174+
"Percent Deletions": (self.out_percent_deletions, self.out_num_deletions),
175+
"Percent Insertions": (self.out_percent_insertions, self.out_num_insertions),
176+
"Percent Word Accuracy": (self.out_percent_word_accuracy, None),
177+
"Ref. words": (None, self.out_ref_words),
178+
"Hyp. words": (None, self.out_hyp_words),
179+
"Aligned words": (None, self.out_aligned_words),
180+
}
181+
182+
outputs_absolute: Dict[str, int] = {}
152183
for line in f:
153-
s = line.split()
154-
if line.startswith("Percent Total Error"):
155-
self.out_wer.set(float(s[4][:-1]))
156-
self.out_num_errors.set(int("".join(s[5:])[1:-1]))
157-
elif line.startswith("Percent Correct"):
158-
self.out_percent_correct.set(float(s[3][:-1]))
159-
self.out_num_correct.set(int("".join(s[4:])[1:-1]))
160-
elif line.startswith("Percent Substitution"):
161-
self.out_percent_substitution.set(float(s[3][:-1]))
162-
self.out_num_substitution.set(int("".join(s[4:])[1:-1]))
163-
elif line.startswith("Percent Deletions"):
164-
self.out_percent_deletions.set(float(s[3][:-1]))
165-
self.out_num_deletions.set(int("".join(s[4:])[1:-1]))
166-
elif line.startswith("Percent Insertions"):
167-
self.out_percent_insertions.set(float(s[3][:-1]))
168-
self.out_num_insertions.set(int("".join(s[4:])[1:-1]))
169-
elif line.startswith("Percent Word Accuracy"):
170-
self.out_percent_word_accuracy.set(float(s[4][:-1]))
171-
elif line.startswith("Ref. words"):
172-
self.out_ref_words.set(int("".join(s[3:])[1:-1]))
173-
elif line.startswith("Hyp. words"):
174-
self.out_hyp_words.set(int("".join(s[3:])[1:-1]))
175-
elif line.startswith("Aligned words"):
176-
self.out_aligned_words.set(int("".join(s[3:])[1:-1]))
184+
key: Optional[str] = ([key for key in output_variables if line.startswith(key)] or [None])[0]
185+
if not key:
186+
continue
187+
pattern = rf"^{re.escape(key)}\s*=\s*((\S+)%)?\s*(\(\s*(\d+)\))?$"
188+
m = re.match(pattern, line)
189+
assert m, f"Could not parse line: {line!r}, does not match to pattern r'{pattern}'"
190+
absolute_s = m.group(4)
191+
if not absolute_s:
192+
assert not output_variables[key][1], f"Expected absolute value for {key}"
193+
continue
194+
outputs_absolute[key] = int(absolute_s)
195+
if key == "Aligned words":
196+
break # that should be the last key, can stop now
197+
198+
assert "Ref. words" in outputs_absolute, "Expected absolute numbers for Ref. words"
199+
num_ref_words = outputs_absolute["Ref. words"]
200+
assert "Percent Total Error" in outputs_absolute, "Expected absolute numbers for Percent Total Error"
201+
outputs_absolute["Percent Word Accuracy"] = num_ref_words - outputs_absolute["Percent Total Error"]
202+
203+
outputs_percentage: Dict[str, float] = {}
204+
for key, absolute in outputs_absolute.items():
205+
if num_ref_words > 0:
206+
percentage = 100.0 * absolute / num_ref_words
207+
else:
208+
percentage = float("nan")
209+
outputs_percentage[key] = (
210+
round(percentage, self.precision_ndigit) if self.precision_ndigit is not None else percentage
211+
)
212+
213+
for key, (percentage_var, absolute_var) in output_variables.items():
214+
if percentage_var is not None:
215+
assert key in outputs_percentage, f"Expected percentage value for {key}"
216+
percentage_var.set(outputs_percentage[key])
217+
if absolute_var is not None:
218+
assert key in outputs_absolute, f"Expected absolute value for {key}"
219+
absolute_var.set(outputs_absolute[key])
177220

178221
def calc_wer(self):
179222
wer = None

tests/job_tests/recognition/test_scoring.py

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def test_sclite_job():
7070
assert sclite_job.out_num_deletions.get() == 2, "Wrong num deletions, %s instead of 2" % str(
7171
sclite_job.out_num_deletions.get()
7272
)
73-
assert sclite_job.out_percent_insertions.get() == 5.9, "Wrong percent insertions, %s instead of 4.5" % str(
73+
assert sclite_job.out_percent_insertions.get() == 5.9, "Wrong percent insertions, %s instead of 5.9" % str(
7474
sclite_job.out_percent_insertions.get()
7575
)
7676
assert sclite_job.out_num_insertions.get() == 1, "Wrong num insertions, %s instead of 1" % str(
@@ -88,3 +88,28 @@ def test_sclite_job():
8888
assert sclite_job.out_aligned_words.get() == 18, "Wrong num aligned words, %s instead of 18" % str(
8989
sclite_job.out_aligned_words.get()
9090
)
91+
92+
# Now test custom precision.
93+
94+
sclite_job = ScliteJob(ref=ref, hyp=hyp, sctk_binary_path=sctk_binary, precision_ndigit=2)
95+
sclite_job._sis_setup_directory()
96+
sclite_job.run()
97+
98+
assert sclite_job.out_wer.get() == 58.82, "Wrong WER, %s instead of 58.82" % str(sclite_job.out_wer.get())
99+
100+
assert sclite_job.out_percent_correct.get() == 47.06, "Wrong percent correct, %s instead of 47.06" % str(
101+
sclite_job.out_percent_correct.get()
102+
)
103+
assert (
104+
sclite_job.out_percent_substitution.get() == 41.18
105+
), "Wrong percent substitution, %s instead of 41.18" % str(sclite_job.out_percent_substitution.get())
106+
107+
assert sclite_job.out_percent_deletions.get() == 11.76, "Wrong percent deletions, %s instead of 11.76" % str(
108+
sclite_job.out_percent_deletions.get()
109+
)
110+
assert sclite_job.out_percent_insertions.get() == 5.88, "Wrong percent insertions, %s instead of 5.88" % str(
111+
sclite_job.out_percent_insertions.get()
112+
)
113+
assert (
114+
sclite_job.out_percent_word_accuracy.get() == 41.18
115+
), "Wrong percent word accuracy, %s instead of 41.18" % str(sclite_job.out_percent_word_accuracy.get())

0 commit comments

Comments
 (0)