22 "FilterSegmentsByListJob" ,
33 "FilterSegmentsByRegexJob" ,
44 "FilterSegmentsByAlignmentConfidenceJob" ,
5+ "FilterRecordingsByAlignmentConfidenceJob" ,
56 "FilterCorpusBySegmentsJob" ,
67 "FilterCorpusRemoveUnknownWordSegmentsJob" ,
78 "FilterCorpusBySegmentDurationJob" ,
89]
910
11+ from collections import defaultdict
1012import gzip
1113import logging
1214import numpy as np
1315import re
1416import xml .etree .cElementTree as ET
15- from typing import Dict , List , Optional , Union
17+ from typing import Dict , List , Optional , Tuple , Union
1618
1719from i6_core import rasr
1820from i6_core .lib import corpus
@@ -155,24 +157,54 @@ def __init__(
155157 if plot :
156158 self .out_plot_avg = self .output_path ("score.png" )
157159
158- def tasks (self ):
159- yield Task ("run" , resume = "run" , mini_task = True )
160+ def _parse_alignment_logs (self , alignment_logs : Dict [int , Path ]) -> Dict [str , List [Tuple [str , float ]]]:
161+ """
162+ :return: Dictionary of recording full names to list of (segment full name, alignment score).
160163
161- def run (self ):
162- segment_dict = {}
163- for task_id , log_file in self .alignment_logs .items ():
164+ Note: the names adhere to the standards of the :class:`i6_core.lib.corpus.Recording`
165+ and :class:`i6_core.lib.corpus.Segment` classes,
166+ in which the segment name is appended to the full recording name (joined by a slash)
167+ to make the full segment name.
168+ """
169+ recording_dict : Dict [str , List [Tuple [str , float ]]] = defaultdict (list )
170+ for _ , log_file in alignment_logs .items ():
164171 logging .info ("Reading: {}" .format (log_file ))
165172 file_path = tk .uncached_path (log_file )
166173 document = ET .parse (uopen (file_path ))
167174 _seg_list = document .findall (".//segment" )
168175 for seg in _seg_list :
169176 avg = seg .find (".//score/avg" )
170- segment_dict [seg .attrib ["full-name" ]] = float (avg .text )
177+ full_seg_name = seg .attrib ["full-name" ]
178+ full_rec_name = "/" .join (full_seg_name .split ("/" )[:- 1 ])
179+ recording_dict [full_rec_name ].append ((full_seg_name , float (avg .text )))
171180 del document
181+ logging .info ("Scores has {} entries." .format (len (recording_dict )))
182+
183+ return recording_dict
184+
185+ def _get_alignment_scores_array (self , recording_dict : Dict [str , List [Tuple [str , float ]]]) -> np .array :
186+ """
187+ :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
188+ :return: Array with the alignment confidence scores **per segment**.
189+ """
190+ return np .asarray (
191+ [
192+ alignment_score
193+ for seg_name_and_score in recording_dict .values ()
194+ for (_ , alignment_score ) in seg_name_and_score
195+ ]
196+ )
172197
173- logging .info ("Scores has {} entries." .format (len (segment_dict )))
174- score_np = np .asarray (list (segment_dict .values ()))
198+ def _get_avg_score_threshold (self , recording_dict : Dict [str , List [Tuple [str , float ]]]) -> float :
199+ """
200+ :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
201+ :return: Alignment score threshold below which samples should be kept,
202+ and above which samples should be discarded.
203+ It's calculated according to the `percentile` and `absolute_threshold` values provided by the user.
204+ """
205+ score_np = self ._get_alignment_scores_array (recording_dict )
175206 logging .info ("Max {}; Min {}; Median {}" .format (score_np .max (), score_np .min (), np .median (score_np )))
207+
176208 avg_score_threshold = np .percentile (score_np , self .percentile )
177209 if np .isnan (avg_score_threshold ):
178210 avg_score_threshold = np .inf
@@ -181,24 +213,29 @@ def run(self):
181213 avg_score_threshold = min (avg_score_threshold , self .absolute_threshold )
182214 logging .info ("Threshold is {}" .format (avg_score_threshold ))
183215
184- if self .plot :
185- import matplotlib
186-
187- matplotlib .use ("Agg" )
188- import matplotlib .pyplot as plt
216+ return avg_score_threshold
189217
190- plot_percentile = np .percentile (score_np , 90 ) # there can be huge outliers
191- np .clip (score_np , 0 , 200 , out = score_np )
192- plt .hist (score_np , bins = 100 , range = (0 , 200 ))
193- plt .xlabel ("Average Maximum-Likelihood Score" )
194- plt .ylabel ("Number of Segments" )
195- plt .title ("Histogram of Alignment Scores" )
196- plt .savefig (fname = self .out_plot_avg .get_path ())
197-
198- # Only keep segments that are below the threshold
199- filtered_segments = [seg for seg , avg in segment_dict .items () if avg <= avg_score_threshold ]
218+ def _filter_segments (
219+ self , recording_dict : Dict [str , List [Tuple [str , float ]]], avg_score_threshold : float
220+ ) -> List [str ]:
221+ """
222+ :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
223+ :param avg_score_threshold: Alignment score threshold below which samples should be kept,
224+ and above which samples should be discarded.
225+ :return: List of segments (represented by their full name) that should be kept.
226+ """
227+ # Only keep segments that are below the threshold.
228+ filtered_segments = [
229+ seg for seg_avg in recording_dict .values () for (seg , avg ) in seg_avg if avg <= avg_score_threshold
230+ ]
200231 logging .info ("Have {} entries after filtering." .format (len (filtered_segments )))
201232
233+ return filtered_segments
234+
235+ def _write_output_segment_files (self , filtered_segments : List [str ]):
236+ """
237+ :param filtered_segments: List of segments (represented by their full name) that should be kept.
238+ """
202239 for idx , segments in enumerate (chunks (filtered_segments , self .num_segments )):
203240 with open (self .out_single_segment_files [idx + 1 ].get_path (), "wt" ) as segment_file :
204241 for segment in segments :
@@ -208,6 +245,122 @@ def run(self):
208245 for segment in filtered_segments :
209246 segment_file .write (segment + "\n " )
210247
248+ def _plot (self , recording_dict : Dict [str , List [Tuple [str , float ]]]):
249+ """
250+ Plots an alignment score.
251+
252+ Note: the plot only takes into account strictly positive values.
253+ For more customizable plotting, it's suggested to use :class:`i6_core.mm.alignment.PlotAlignmentJob` instead.
254+ """
255+ import matplotlib
256+ import matplotlib .pyplot as plt
257+
258+ matplotlib .use ("Agg" )
259+
260+ score_np = self ._get_alignment_scores_array (recording_dict )
261+
262+ # Before filtering.
263+ np .clip (score_np , 0 , 200 , out = score_np )
264+ plt .hist (score_np , bins = 100 , range = (0 , 200 ))
265+ plt .xlabel ("Average Maximum-Likelihood Score" )
266+ plt .ylabel ("Number of Segments" )
267+ plt .title ("Histogram of Alignment Scores" )
268+ plt .savefig (fname = self .out_plot_avg .get_path ())
269+
270+ def tasks (self ):
271+ yield Task ("run" , resume = "run" , mini_task = True )
272+
273+ def run (self ):
274+ recording_dict = self ._parse_alignment_logs (self .alignment_logs )
275+ avg_score_threshold = self ._get_avg_score_threshold (recording_dict )
276+ filtered_segments = self ._filter_segments (recording_dict , avg_score_threshold )
277+ self ._write_output_segment_files (filtered_segments )
278+ self ._plot (recording_dict )
279+
280+
281+ class FilterRecordingsByAlignmentConfidenceJob (FilterSegmentsByAlignmentConfidenceJob ):
282+ """
283+ Filter segments like :class:`FilterSegmentsByAlignmentConfidenceJob` does.
284+ However, instead of taking into account the alignment confidence of a single segment,
285+ take into account the average alignment confidence of the whole recording.
286+ """
287+
288+ def __init__ (
289+ self ,
290+ alignment_logs : Dict [int , Path ],
291+ percentile : float ,
292+ crp : Optional [rasr .CommonRasrParameters ] = None ,
293+ plot : bool = True ,
294+ absolute_threshold : Optional [float ] = None ,
295+ ):
296+ """
297+ :param alignment_logs: Mapping of `task_id` into log file.
298+ Can be directly used as the output `out_log_file` of the job :class:`i6_core.mm.AlignmentJob`.
299+ :param percentile: Percent of recordings whose segments should be keep, in the range `(0,100]`.
300+ Used directly in :func:`np.percentile`.
301+ :param crp: Used to set the number of output segments.
302+ If `None` (default value), all segments in all alignment log files are considered.
303+ :param plot: Whether to plot the distribution of alignment scores.
304+ :param absolute_threshold: All segments from a recording are discarded
305+ if the recording's average alignment score is above this number.
306+ """
307+ super ().__init__ (
308+ alignment_logs = alignment_logs ,
309+ percentile = percentile ,
310+ crp = crp ,
311+ plot = plot ,
312+ absolute_threshold = absolute_threshold ,
313+ )
314+
315+ self .out_kept_recordings = self .output_path ("kept_recordings.txt" )
316+ self .out_discarded_recordings = self .output_path ("discarded_recordings.txt" )
317+
318+ def _get_avg_confidence_per_recording (self , recording_dict : Dict [str , List [Tuple [str , float ]]]) -> Dict [str , float ]:
319+ """
320+ :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
321+ :return: Dictionary of recording full names to average recording alignment score
322+ (calculated as the average of all alignment scores of the segments that compose the recording).
323+ """
324+ return {
325+ full_rec_name : np .average ([conf for (_ , conf ) in seg_and_confs ])
326+ for full_rec_name , seg_and_confs in recording_dict .items ()
327+ }
328+
329+ def _get_alignment_scores_array (self , recording_dict : Dict [str , List [Tuple [str , float ]]]) -> np .array :
330+ """
331+ :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
332+ :return: Array with the alignment confidence scores **per recording**.
333+ """
334+ return np .asarray (list (self ._get_avg_confidence_per_recording (recording_dict ).values ()))
335+
336+ def _filter_segments (
337+ self , recording_dict : Dict [str , List [Tuple [str , float ]]], avg_score_threshold : float
338+ ) -> List [str ]:
339+ """
340+ :param recording_dict: Dictionary of recording full names to list of (segment full name, alignment score).
341+ :param avg_score_threshold: Alignment score threshold below which samples should be kept,
342+ and above which samples should be discarded.
343+ :return: List of segments (represented by their full name) that should be kept.
344+ """
345+ recording_to_average_conf = self ._get_avg_confidence_per_recording (recording_dict )
346+
347+ filtered_segments = []
348+ # Write outputs that are local to this job here to avoid passing more variables around.
349+ with uopen (self .out_kept_recordings .get_path (), "wt" ) as f_kept , uopen (
350+ self .out_discarded_recordings .get_path (), "wt"
351+ ) as f_discarded :
352+ for full_rec_name , avg_alignment_score in recording_to_average_conf .items ():
353+ if avg_alignment_score <= avg_score_threshold :
354+ # Keep the whole recording.
355+ f_kept .write (f"{ full_rec_name } { avg_alignment_score } \n " )
356+ for segment_name , _ in recording_dict [full_rec_name ]:
357+ filtered_segments .append (segment_name )
358+ else :
359+ # Discard the whole recording.
360+ f_discarded .write (f"{ full_rec_name } { avg_alignment_score } \n " )
361+
362+ return filtered_segments
363+
211364
212365class FilterCorpusBySegmentsJob (Job ):
213366 __sis_hash_exclude__ = {"delete_empty_recordings" : False }
0 commit comments