Skip to content

Commit 4795d0a

Browse files
authored
Add FilterRecordingsByAlignmentConfidenceJob (#603)
Same as `FilterSegmentsByAlignmentConfidenceJob`, but the decision is taken per recording: all segments of a recording are discarded if the recording's average confidence score (defined as `avg(segment.score for segment in recording)`) is greater than the threshold derived from the user-provided `percentile` and `absolute_threshold`.
1 parent 42fc209 commit 4795d0a

File tree

1 file changed

+177
-24
lines changed

1 file changed

+177
-24
lines changed

corpus/filter.py

Lines changed: 177 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
"FilterSegmentsByListJob",
33
"FilterSegmentsByRegexJob",
44
"FilterSegmentsByAlignmentConfidenceJob",
5+
"FilterRecordingsByAlignmentConfidenceJob",
56
"FilterCorpusBySegmentsJob",
67
"FilterCorpusRemoveUnknownWordSegmentsJob",
78
"FilterCorpusBySegmentDurationJob",
89
]
910

11+
from collections import defaultdict
1012
import gzip
1113
import logging
1214
import numpy as np
1315
import re
1416
import xml.etree.cElementTree as ET
15-
from typing import Dict, List, Optional, Union
17+
from typing import Dict, List, Optional, Tuple, Union
1618

1719
from i6_core import rasr
1820
from i6_core.lib import corpus
@@ -155,24 +157,54 @@ def __init__(
155157
if plot:
156158
self.out_plot_avg = self.output_path("score.png")
157159

158-
def tasks(self):
159-
yield Task("run", resume="run", mini_task=True)
160+
def _parse_alignment_logs(self, alignment_logs: Dict[int, Path]) -> Dict[str, List[Tuple[str, float]]]:
    """
    Parse the alignment log files and group the average score of every segment by recording.

    :param alignment_logs: Mapping of `task_id` into alignment log file.
        Only the log files are read; the task ids are not used.
    :return: Dictionary of recording full names to list of (segment full name, alignment score).

    Note: the names adhere to the standards of the :class:`i6_core.lib.corpus.Recording`
    and :class:`i6_core.lib.corpus.Segment` classes,
    in which the segment name is appended to the full recording name (joined by a slash)
    to make the full segment name.
    """
    recording_dict: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for log_file in alignment_logs.values():
        logging.info("Reading: {}".format(log_file))
        file_path = tk.uncached_path(log_file)
        # Parse inside a context manager so the log file handle is closed right after reading
        # instead of staying open until garbage collection.
        with uopen(file_path) as log_stream:
            document = ET.parse(log_stream)
        for seg in document.findall(".//segment"):
            avg = seg.find(".//score/avg")
            full_seg_name = seg.attrib["full-name"]
            # The recording name is everything before the last slash of the segment name
            # (empty string if the segment name contains no slash).
            full_rec_name = full_seg_name.rpartition("/")[0]
            recording_dict[full_rec_name].append((full_seg_name, float(avg.text)))
        # Drop the parsed tree before the next log file is read.
        del document
    # Note: this counts the recordings (dictionary keys), not the individual segments.
    logging.info("Scores has {} entries.".format(len(recording_dict)))

    return recording_dict
184+
185+
def _get_alignment_scores_array(self, recording_dict: Dict[str, List[Tuple[str, float]]]) -> np.array:
186+
"""
187+
:param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
188+
:return: Array with the alignment confidence scores **per segment**.
189+
"""
190+
return np.asarray(
191+
[
192+
alignment_score
193+
for seg_name_and_score in recording_dict.values()
194+
for (_, alignment_score) in seg_name_and_score
195+
]
196+
)
172197

173-
logging.info("Scores has {} entries.".format(len(segment_dict)))
174-
score_np = np.asarray(list(segment_dict.values()))
198+
def _get_avg_score_threshold(self, recording_dict: Dict[str, List[Tuple[str, float]]]) -> float:
199+
"""
200+
:param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
201+
:return: Alignment score threshold below which samples should be kept,
202+
and above which samples should be discarded.
203+
It's calculated according to the `percentile` and `absolute_threshold` values provided by the user.
204+
"""
205+
score_np = self._get_alignment_scores_array(recording_dict)
175206
logging.info("Max {}; Min {}; Median {}".format(score_np.max(), score_np.min(), np.median(score_np)))
207+
176208
avg_score_threshold = np.percentile(score_np, self.percentile)
177209
if np.isnan(avg_score_threshold):
178210
avg_score_threshold = np.inf
@@ -181,24 +213,29 @@ def run(self):
181213
avg_score_threshold = min(avg_score_threshold, self.absolute_threshold)
182214
logging.info("Threshold is {}".format(avg_score_threshold))
183215

184-
if self.plot:
185-
import matplotlib
186-
187-
matplotlib.use("Agg")
188-
import matplotlib.pyplot as plt
216+
return avg_score_threshold
189217

190-
plot_percentile = np.percentile(score_np, 90) # there can be huge outliers
191-
np.clip(score_np, 0, 200, out=score_np)
192-
plt.hist(score_np, bins=100, range=(0, 200))
193-
plt.xlabel("Average Maximum-Likelihood Score")
194-
plt.ylabel("Number of Segments")
195-
plt.title("Histogram of Alignment Scores")
196-
plt.savefig(fname=self.out_plot_avg.get_path())
197-
198-
# Only keep segments that are below the threshold
199-
filtered_segments = [seg for seg, avg in segment_dict.items() if avg <= avg_score_threshold]
218+
def _filter_segments(
219+
self, recording_dict: Dict[str, List[Tuple[str, float]]], avg_score_threshold: float
220+
) -> List[str]:
221+
"""
222+
:param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
223+
:param avg_score_threshold: Alignment score threshold below which samples should be kept,
224+
and above which samples should be discarded.
225+
:return: List of segments (represented by their full name) that should be kept.
226+
"""
227+
# Only keep segments that are below the threshold.
228+
filtered_segments = [
229+
seg for seg_avg in recording_dict.values() for (seg, avg) in seg_avg if avg <= avg_score_threshold
230+
]
200231
logging.info("Have {} entries after filtering.".format(len(filtered_segments)))
201232

233+
return filtered_segments
234+
235+
def _write_output_segment_files(self, filtered_segments: List[str]):
236+
"""
237+
:param filtered_segments: List of segments (represented by their full name) that should be kept.
238+
"""
202239
for idx, segments in enumerate(chunks(filtered_segments, self.num_segments)):
203240
with open(self.out_single_segment_files[idx + 1].get_path(), "wt") as segment_file:
204241
for segment in segments:
@@ -208,6 +245,122 @@ def run(self):
208245
for segment in filtered_segments:
209246
segment_file.write(segment + "\n")
210247

248+
def _plot(self, recording_dict: Dict[str, List[Tuple[str, float]]]):
    """
    Plot a histogram of the alignment scores and save it to the path of `self.out_plot_avg`.

    :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).

    Note: the scores are clipped to the range `[0, 200]` before plotting,
    so every value outside of that range is counted in the first or last bin respectively.
    For more customizable plotting, it's suggested to use :class:`i6_core.mm.alignment.PlotAlignmentJob` instead.
    """
    import matplotlib

    # Select the non-interactive backend *before* pyplot is imported:
    # older matplotlib versions fix the backend at pyplot import time and ignore a later `use()` call.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    score_np = self._get_alignment_scores_array(recording_dict)

    # Histogram of the scores before filtering.
    # The clipping is done in place on the array returned by `_get_alignment_scores_array`.
    np.clip(score_np, 0, 200, out=score_np)
    plt.hist(score_np, bins=100, range=(0, 200))
    plt.xlabel("Average Maximum-Likelihood Score")
    plt.ylabel("Number of Segments")
    plt.title("Histogram of Alignment Scores")
    plt.savefig(fname=self.out_plot_avg.get_path())
269+
270+
def tasks(self):
    """
    Yield the single task of this job: `run`, declared as a mini task with `run` as its resume function.
    """
    yield Task("run", resume="run", mini_task=True)
272+
273+
def run(self):
    """
    Parse the alignment logs, compute the score threshold, filter the segments,
    write the resulting segment files and, if requested, plot the score histogram.
    """
    recording_dict = self._parse_alignment_logs(self.alignment_logs)
    avg_score_threshold = self._get_avg_score_threshold(recording_dict)
    filtered_segments = self._filter_segments(recording_dict, avg_score_threshold)
    self._write_output_segment_files(filtered_segments)
    # The plot output path only exists when plotting was requested in the constructor,
    # so plotting unconditionally would fail for `plot=False`.
    if self.plot:
        self._plot(recording_dict)
279+
280+
281+
class FilterRecordingsByAlignmentConfidenceJob(FilterSegmentsByAlignmentConfidenceJob):
    """
    Recording-level variant of :class:`FilterSegmentsByAlignmentConfidenceJob`.

    The keep/discard decision is not taken for every segment on its own.
    Instead, the average alignment confidence over all segments of a recording is compared
    against the threshold, and either all segments of the recording are kept or none of them is.
    """

    def __init__(
        self,
        alignment_logs: Dict[int, Path],
        percentile: float,
        crp: Optional[rasr.CommonRasrParameters] = None,
        plot: bool = True,
        absolute_threshold: Optional[float] = None,
    ):
        """
        :param alignment_logs: Mapping of `task_id` into log file.
            Can be directly used as the output `out_log_file` of the job :class:`i6_core.mm.AlignmentJob`.
        :param percentile: Percent of recordings whose segments should be kept, in the range `(0,100]`.
            Used directly in :func:`np.percentile`.
        :param crp: Used to set the number of output segments.
            If `None` (default value), all segments in all alignment log files are considered.
        :param plot: Whether to plot the distribution of alignment scores.
        :param absolute_threshold: All segments from a recording are discarded
            if the recording's average alignment score is above this number.
        """
        super().__init__(
            alignment_logs=alignment_logs,
            percentile=percentile,
            crp=crp,
            plot=plot,
            absolute_threshold=absolute_threshold,
        )

        # Additional outputs of this job: one line "<recording full name> <average score>" per recording.
        self.out_kept_recordings = self.output_path("kept_recordings.txt")
        self.out_discarded_recordings = self.output_path("discarded_recordings.txt")

    def _get_avg_confidence_per_recording(self, recording_dict: Dict[str, List[Tuple[str, float]]]) -> Dict[str, float]:
        """
        :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
        :return: Dictionary of recording full names to average recording alignment score
            (calculated as the average of all alignment scores of the segments that compose the recording).
        """
        avg_conf_by_recording = {}
        for full_rec_name, seg_and_confs in recording_dict.items():
            confidences = [conf for (_, conf) in seg_and_confs]
            avg_conf_by_recording[full_rec_name] = np.average(confidences)
        return avg_conf_by_recording

    def _get_alignment_scores_array(self, recording_dict: Dict[str, List[Tuple[str, float]]]) -> np.ndarray:
        """
        :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
        :return: Array with the alignment confidence scores **per recording**.
        """
        per_recording_avg = self._get_avg_confidence_per_recording(recording_dict)
        return np.asarray(list(per_recording_avg.values()))

    def _filter_segments(
        self, recording_dict: Dict[str, List[Tuple[str, float]]], avg_score_threshold: float
    ) -> List[str]:
        """
        Select the segments of all recordings whose average alignment score does not exceed the threshold.

        As a side effect, the kept and the discarded recordings are written
        (together with their average score) to `out_kept_recordings` and `out_discarded_recordings`.

        :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
        :param avg_score_threshold: Alignment score threshold below which samples should be kept,
            and above which samples should be discarded.
        :return: List of segments (represented by their full name) that should be kept.
        """
        recording_to_average_conf = self._get_avg_confidence_per_recording(recording_dict)
        kept_path = self.out_kept_recordings.get_path()
        discarded_path = self.out_discarded_recordings.get_path()

        filtered_segments = []
        # The job-specific outputs are written here so that no extra state has to be passed around.
        with uopen(kept_path, "wt") as f_kept, uopen(discarded_path, "wt") as f_discarded:
            for full_rec_name, avg_alignment_score in recording_to_average_conf.items():
                if avg_alignment_score <= avg_score_threshold:
                    # The recording passes: keep all of its segments.
                    f_kept.write(f"{full_rec_name} {avg_alignment_score}\n")
                    filtered_segments.extend(segment_name for segment_name, _ in recording_dict[full_rec_name])
                else:
                    # The recording fails: none of its segments is kept.
                    f_discarded.write(f"{full_rec_name} {avg_alignment_score}\n")

        return filtered_segments
363+
211364

212365
class FilterCorpusBySegmentsJob(Job):
213366
__sis_hash_exclude__ = {"delete_empty_recordings": False}

0 commit comments

Comments
 (0)