Skip to content

Commit a5a4600

Browse files
authored
Filter*ByAlignmentConfidenceJob: add flag to disregard alignments that didn't reach a final state (#608)
* Add job
* Fix job: add relevant imports, remove irrelevant code
* Remove job
* Add `remove_dnf_alignments` flag to parse alignment logs function
* Set flag to true in recording filtering
* Simplify code as per offline review
* Add DNF option to base segment filtering job (will be much more useful in the long run)
1 parent 1129a59 commit a5a4600

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

corpus/filter.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,26 +128,37 @@ def run(self):
128128

129129

130130
class FilterSegmentsByAlignmentConfidenceJob(Job):
131+
__sis_hash_exclude__ = {"remove_dnf_alignments": False}
132+
131133
def __init__(
132134
self,
133135
alignment_logs: Dict[int, Path],
134136
percentile: float,
135137
crp: Optional[rasr.CommonRasrParameters] = None,
136138
plot: bool = True,
137139
absolute_threshold: Optional[float] = None,
140+
remove_dnf_alignments: bool = False,
138141
):
139142
"""
140143
:param alignment_logs: alignment_job.out_log_file; task_id -> log_file
141144
:param percentile: percent of alignment segments to keep. should be in (0,100]. for :func:`np.percentile`
142145
:param crp: used to set the number of output segments. if none, number of alignment log files is used instead.
143146
:param plot: plot the distribution of alignment scores
144147
:param absolute_threshold: alignments with score above this number are discarded
148+
:param remove_dnf_alignments: Whether alignments that haven't reached a final state
149+
should be excluded from the final statistics dictionary.
150+
151+
Note that these alignments haven't made it to the final alignment caches,
152+
so parsing them is inconsistent with respect to the final caches
153+
and pollutes any statistics retrieved from the data.
154+
The default value is `False` only for retrocompatibility purposes, and `True` is recommended instead.
145155
"""
146156
self.alignment_logs = alignment_logs # alignment_job.log_file
147157
self.percentile = percentile
148158
self.absolute_threshold = absolute_threshold
149159
self.num_segments = len(alignment_logs) if crp is None else crp.concurrent
150160
self.plot = plot
161+
self.remove_dnf_alignments = remove_dnf_alignments
151162

152163
self.out_single_segment_files = dict(
153164
(i, self.output_path("segments.%d" % i)) for i in range(1, self.num_segments + 1)
@@ -157,8 +168,18 @@ def __init__(
157168
if plot:
158169
self.out_plot_avg = self.output_path("score.png")
159170

160-
def _parse_alignment_logs(self, alignment_logs: Dict[int, Path]) -> Dict[str, List[Tuple[str, float]]]:
171+
def _parse_alignment_logs(
172+
self, alignment_logs: Dict[int, Path], remove_dnf_alignments: bool = False
173+
) -> Dict[str, List[Tuple[str, float]]]:
161174
"""
175+
:param alignment_logs: Alignment logs to analyze.
176+
:param remove_dnf_alignments: Whether alignments that haven't reached a final state
177+
should be excluded from the final statistics dictionary.
178+
179+
Note that these alignments haven't made it to the final alignment caches,
180+
so parsing them is inconsistent with respect to the final caches
181+
and pollutes any statistics retrieved from the data.
182+
The default value is `False` only for retrocompatibility purposes, and `True` is recommended instead.
162183
:return: Dictionary of recording full names to list of (segment full name, alignment score).
163184
164185
Note: the names adhere to the standards of the :class:`i6_core.lib.corpus.Recording`
@@ -173,6 +194,10 @@ def _parse_alignment_logs(self, alignment_logs: Dict[int, Path]) -> Dict[str, Li
173194
document = ET.parse(uopen(file_path))
174195
_seg_list = document.findall(".//segment")
175196
for seg in _seg_list:
197+
if remove_dnf_alignments and any(
198+
"Alignment did not reach any final state." in warning.text for warning in seg.findall(".//warning")
199+
):
200+
continue
176201
avg = seg.find(".//score/avg")
177202
full_seg_name = seg.attrib["full-name"]
178203
full_rec_name = "/".join(full_seg_name.split("/")[:-1])
@@ -271,7 +296,9 @@ def tasks(self):
271296
yield Task("run", resume="run", mini_task=True)
272297

273298
def run(self):
274-
recording_dict = self._parse_alignment_logs(self.alignment_logs)
299+
recording_dict = self._parse_alignment_logs(
300+
self.alignment_logs, remove_dnf_alignments=self.remove_dnf_alignments
301+
)
275302
avg_score_threshold = self._get_avg_score_threshold(recording_dict)
276303
filtered_segments = self._filter_segments(recording_dict, avg_score_threshold)
277304
self._write_output_segment_files(filtered_segments)
@@ -361,6 +388,14 @@ def _filter_segments(
361388

362389
return filtered_segments
363390

391+
def run(self):
392+
# Alignments that haven't reached a final state can bias the mean computation, so they're removed.
393+
recording_dict = self._parse_alignment_logs(self.alignment_logs, remove_dnf_alignments=True)
394+
avg_score_threshold = self._get_avg_score_threshold(recording_dict)
395+
filtered_segments = self._filter_segments(recording_dict, avg_score_threshold)
396+
self._write_output_segment_files(filtered_segments)
397+
self._plot(recording_dict)
398+
364399

365400
class FilterCorpusBySegmentsJob(Job):
366401
__sis_hash_exclude__ = {"delete_empty_recordings": False}

0 commit comments

Comments
 (0)