diff --git a/lm/__init__.py b/lm/__init__.py
index b5960e96b..944dd5893 100644
--- a/lm/__init__.py
+++ b/lm/__init__.py
@@ -3,3 +3,4 @@
 from .reverse_arpa import *
 from .vocabulary import *
 from .srilm import *
+from .util import *
diff --git a/lm/lm_image.py b/lm/lm_image.py
index f5de4ac4e..1847ae670 100644
--- a/lm/lm_image.py
+++ b/lm/lm_image.py
@@ -21,7 +21,7 @@ def __init__(
         extra_config=None,
         extra_post_config=None,
         encoding="utf-8",
-        mem=2,
+        mem=12,
     ):
         kwargs = locals()
         del kwargs["self"]
diff --git a/lm/util.py b/lm/util.py
new file mode 100644
index 000000000..658cf5e52
--- /dev/null
+++ b/lm/util.py
@@ -0,0 +1,33 @@
+from typing import List, Optional, Tuple
+
+import i6_core.rasr as rasr
+
+
+def _has_image(config: rasr.RasrConfig, post_config: Optional[rasr.RasrConfig]) -> bool:
+    res = config._get("image") is not None
+    res = res or (post_config is not None and post_config._get("image") is not None)
+    return res
+
+
+def find_arpa_lms(
+    lm_config: rasr.RasrConfig, lm_post_config: Optional[rasr.RasrConfig] = None
+) -> List[Tuple[rasr.RasrConfig, Optional[rasr.RasrConfig]]]:
+    """
+    Recursively collect all ARPA (sub-)LM configs that do not have an image set yet.
+
+    :param lm_config: LM config to search, may be of type "combine" with nested sub-LMs
+    :param lm_post_config: post config corresponding to lm_config, if any
+    :return: list of (lm_config, lm_post_config) pairs, one per ARPA LM without an image
+    """
+    result = []
+
+    if lm_config.type == "ARPA":
+        if not _has_image(lm_config, lm_post_config):
+            result.append((lm_config, lm_post_config))
+    elif lm_config.type == "combine":
+        for i in range(1, lm_config.num_lms + 1):
+            sub_lm_config = lm_config[f"lm-{i}"]
+            sub_lm_post_config = lm_post_config[f"lm-{i}"] if lm_post_config is not None else None
+            result += find_arpa_lms(sub_lm_config, sub_lm_post_config)
+
+    return result
diff --git a/rasr/crp.py b/rasr/crp.py
index 68029d903..1a74baa76 100644
--- a/rasr/crp.py
+++ b/rasr/crp.py
@@ -79,7 +79,7 @@ def set_executables(self, rasr_binary_path, rasr_arch="linux-x86_64-standard"):
         self.flf_tool_exe = rasr_binary_path.join_right(f"flf-tool.{rasr_arch}")
         self.kws_tool_exe = None  # does not exist
         self.lattice_processor_exe = rasr_binary_path.join_right(f"lattice-processor.{rasr_arch}")
-        self.lm_util_exe = None  # does not exist
+        self.lm_util_exe = rasr_binary_path.join_right(f"lm-util.{rasr_arch}")
         self.nn_trainer_exe = rasr_binary_path.join_right(f"nn-trainer.{rasr_arch}")
         self.speech_recognizer_exe = rasr_binary_path.join_right(f"speech-recognizer.{rasr_arch}")
diff --git a/recognition/advanced_tree_search.py b/recognition/advanced_tree_search.py
index 0e91f89c9..52fd6d586 100644
--- a/recognition/advanced_tree_search.py
+++ b/recognition/advanced_tree_search.py
@@ -10,6 +10,8 @@
 
 Path = setup_path(__package__)
 
+import copy
+from enum import Enum
 import math
 import os
 import shutil
@@ -69,25 +71,6 @@ def run(self):
     def cleanup_before_run(self, cmd, retry, *args):
         util.backup_if_exists("lm_and_state_tree.log")
 
-    @classmethod
-    def find_arpa_lms(cls, lm_config, lm_post_config=None):
-        result = []
-
-        def has_image(c, pc):
-            res = c._get("image") is not None
-            res = res or (pc is not None and pc._get("image") is not None)
-            return res
-
-        if lm_config.type == "ARPA":
-            if not has_image(lm_config, lm_post_config):
-                result.append((lm_config, lm_post_config))
-        elif lm_config.type == "combine":
-            for i in range(1, lm_config.num_lms + 1):
-                sub_lm_config = lm_config["lm-%d" % i]
-                sub_lm_post_config = lm_post_config["lm-%d" % i] if lm_post_config is not None else None
-                result += cls.find_arpa_lms(sub_lm_config, sub_lm_post_config)
-        return result
-
     @classmethod
     def create_config(cls, crp, feature_scorer, extra_config, extra_post_config, **kwargs):
         config, post_config = rasr.build_config_from_mapping(
@@ -117,7 +100,7 @@ def create_config(cls, crp, feature_scorer, extra_config, extra_post_config, **k
         config.flf_lattice_tool.network.recognizer.feature_extraction.file = "dummy.flow"
         config.flf_lattice_tool.network.recognizer.lm.scale = 1.0
 
-        arpa_lms = cls.find_arpa_lms(
+        arpa_lms = lm.find_arpa_lms(
             config.flf_lattice_tool.network.recognizer.lm,
             post_config.flf_lattice_tool.network.recognizer.lm if post_config is not None else None,
         )
@@ -149,6 +132,11 @@ def hash(cls, kwargs):
 
 
 class AdvancedTreeSearchJob(rasr.RasrCommand, Job):
+    class LmCacheMethod(Enum):
+        JOINED = "joined"
+        SEPARATE = "separate"
+        NONE = "none"
+
     def __init__(
         self,
         crp: rasr.CommonRasrParameters,
@@ -167,6 +155,7 @@ def __init__(
         lmgc_mem: float = 12.0,
         lmgc_alias: Optional[str] = None,
         lmgc_scorer: Optional[rasr.FeatureScorer] = None,
+        lm_cache_method: LmCacheMethod = LmCacheMethod.JOINED,
         model_combination_config: Optional[rasr.RasrConfig] = None,
         model_combination_post_config: Optional[rasr.RasrConfig] = None,
         extra_config: Optional[rasr.RasrConfig] = None,
@@ -190,6 +179,11 @@ def __init__(
         :param lmgc_mem: Memory requirement for the AdvancedTreeSearchLmImageAndGlobalCacheJob
         :param lmgc_alias: Alias for the AdvancedTreeSearchLmImageAndGlobalCacheJob
         :param lmgc_scorer: Dummy scorer for the AdvancedTreeSearchLmImageAndGlobalCacheJob which is required but unused
+        :param lm_cache_method: Specifies how the LM image and the global cache should be created:
+            JOINED (default) -> automatically create the LM images and the global cache as outputs of a single job.
+                Note that this can cause hash issues if e.g. many search jobs with different TDP configurations are started.
+            SEPARATE -> automatically create the LM images and the global cache in separate jobs. This is to be preferred most of the time.
+            NONE -> do not create LM images or a global cache as part of this job at all
         :param model_combination_config: Configuration for model combination
         :param model_combination_post_config: Post config for model combination
         :param extra_config: Additional Config for recognition
@@ -206,7 +200,9 @@ def __init__(
             self.config,
             self.post_config,
             self.lm_gc_job,
-        ) = AdvancedTreeSearchJob.create_config(**kwargs)
+            self.gc_job,
+            self.lm_image_jobs,
+        ) = self.create_config(**kwargs)
         self.feature_flow = feature_flow
         self.exe = self.select_exe(crp.flf_tool_exe, "flf-tool")
         self.concurrent = crp.concurrent
@@ -286,18 +282,17 @@ def create_config(
         lmgc_mem: float,
         lmgc_alias: Optional[str],
         lmgc_scorer: Optional[rasr.FeatureScorer],
+        lm_cache_method: LmCacheMethod,
         model_combination_config: Optional[rasr.RasrConfig],
         model_combination_post_config: Optional[rasr.RasrConfig],
         extra_config: Optional[rasr.RasrConfig],
         extra_post_config: Optional[rasr.RasrConfig],
         **kwargs,
     ):
-        lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
-            crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
-        )
-        if lmgc_alias is not None:
-            lm_gc.add_alias(lmgc_alias)
-        lm_gc.rqmt["mem"] = lmgc_mem
+        def add_lm_config_to_crp(crp: rasr.CommonRasrParameters, lm_config: rasr.RasrConfig):
+            crp = copy.deepcopy(crp)
+            crp.language_model_config = lm_config
+            return crp
 
         search_parameters = cls.update_search_parameters(search_parameters)
@@ -396,15 +391,51 @@ def create_config(
                 "cache_high"
             ]
 
-        post_config.flf_lattice_tool.global_cache.read_only = True
-        post_config.flf_lattice_tool.global_cache.file = lm_gc.out_global_cache
-
-        arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms(
-            config.flf_lattice_tool.network.recognizer.lm,
-            post_config.flf_lattice_tool.network.recognizer.lm,
+        # Handle caching of ARPA LMs and maybe build global cache
+        arpa_lms = lm.find_arpa_lms(
+            config.flf_lattice_tool.network.recognizer.lm, post_config.flf_lattice_tool.network.recognizer.lm
         )
-        for i, lm_config in enumerate(arpa_lms):
-            lm_config[1].image = lm_gc.out_lm_images[i + 1]
+        if lm_cache_method == cls.LmCacheMethod.NONE:
+            gc_job = None
+            lm_gc = None
+            lm_images = None
+            gc = None
+            lm_image_jobs = {}
+        elif lm_cache_method == cls.LmCacheMethod.SEPARATE:
+            gc_job = BuildGlobalCacheJob(crp, extra_config, extra_post_config)
+
+            lm_image_jobs = {
+                (i + 1): lm.CreateLmImageJob(
+                    add_lm_config_to_crp(crp, lm_config), extra_config=extra_config, extra_post_config=extra_post_config
+                )
+                for i, (lm_config, _) in enumerate(arpa_lms)
+            }
+
+            gc = gc_job.out_global_cache
+            lm_images = {k: v.out_image for k, v in lm_image_jobs.items()}
+
+            lm_gc = None
+        elif lm_cache_method == cls.LmCacheMethod.JOINED:
+            lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
+                crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
+            )
+            if lmgc_alias is not None:
+                lm_gc.add_alias(lmgc_alias)
+            lm_gc.rqmt["mem"] = lmgc_mem
+
+            gc = lm_gc.out_global_cache
+            lm_images = lm_gc.out_lm_images
+
+            gc_job = None
+            lm_image_jobs = {}
+        else:
+            raise TypeError("Argument `lm_cache_method` must be of type `AdvancedTreeSearchJob.LmCacheMethod`")
+
+        post_config.flf_lattice_tool.global_cache.read_only = True
+        if lm_cache_method != cls.LmCacheMethod.NONE:
+            post_config.flf_lattice_tool.global_cache.file = gc
+            for i, (_, lm_post_config) in enumerate(arpa_lms):
+                lm_post_config.image = lm_images[i + 1]
 
         # Remaining Flf-network
@@ -438,11 +469,11 @@ def create_config(
         config._update(extra_config)
         post_config._update(extra_post_config)
 
-        return config, post_config, lm_gc
+        return config, post_config, lm_gc, gc_job, lm_image_jobs
 
     @classmethod
     def hash(cls, kwargs):
-        config, post_config, lm_gc = cls.create_config(**kwargs)
+        config, post_config, *jobs = cls.create_config(**kwargs)
         return super().hash(
             {
                 "config": config,
@@ -505,6 +536,8 @@ def __init__(
             self.config,
             self.post_config,
             self.lm_gc_job,
+            self.gc_job,
+            self.lm_image_jobs,
         ) = AdvancedTreeSearchWithRescoringJob.create_config(**kwargs)
 
     @classmethod
@@ -518,7 +551,7 @@ def create_config(
         rescoring_lookahead_scale,
         **kwargs,
     ):
-        config, post_config, lm_gc_job = super().create_config(**kwargs)
+        config, *remainder = super().create_config(**kwargs)
 
         config.flf_lattice_tool.network.recognizer.links = "rescore"
 
@@ -533,7 +566,7 @@ def create_config(
         rescore_config.lookahead_scale = rescoring_lookahead_scale
         rescore_config.lm = rescoring_lm_config
 
-        return config, post_config, lm_gc_job
+        return config, *remainder
 
 
 class BidirectionalAdvancedTreeSearchJob(rasr.RasrCommand, Job):
@@ -817,7 +850,8 @@ class BuildGlobalCacheJob(rasr.RasrCommand, Job):
     Standalone job to create the global-cache for advanced-tree-search
     """
 
-    def __init__(self, crp, extra_config=None, extra_post_config=None):
+    def __init__(self, crp, extra_config=None, extra_post_config=None, mem=12):
         """
         :param rasr.CommonRasrParameters crp: common RASR params (required: lexicon, acoustic_model, language_model, recognizer)
         :param rasr.Configuration extra_config: overlay config that influences the Job's hash
+        :param int mem: memory requirement of the job in GB
@@ -837,7 +870,7 @@ def __init__(self, crp, extra_config=None, extra_post_config=None):
         self.out_log_file = self.log_file_output_path("build_global_cache", crp, False)
         self.out_global_cache = self.output_path("global.cache", cached=True)
 
-        self.rqmt = {"time": 1, "cpu": 1, "mem": 2}
+        self.rqmt = {"time": 1, "cpu": 1, "mem": mem}
 
     def tasks(self):
         yield Task("create_files", mini_task=True)
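
---

Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new `lm_cache_method` option could be passed to a search job, and how `find_arpa_lms` can be called on its own. It assumes `i6_core.recognition` re-exports `AdvancedTreeSearchJob`, and that `crp`, `feature_flow`, and `feature_scorer` are set up elsewhere; all names and option values are placeholders.

import i6_core.lm as lm
import i6_core.recognition as recog

# `crp`, `feature_flow` and `feature_scorer` are assumed to be configured elsewhere.
search_job = recog.AdvancedTreeSearchJob(
    crp=crp,
    feature_flow=feature_flow,
    feature_scorer=feature_scorer,
    # SEPARATE builds one CreateLmImageJob per ARPA LM plus a standalone
    # BuildGlobalCacheJob, so the LM images can be shared between search jobs
    # whose other settings (e.g. TDPs) differ, instead of being rebuilt per job.
    lm_cache_method=recog.AdvancedTreeSearchJob.LmCacheMethod.SEPARATE,
)

# find_arpa_lms can also be used directly: it recursively walks "combine" LMs
# and returns (config, post_config) pairs for every ARPA LM without an image.
arpa_lms = lm.find_arpa_lms(crp.language_model_config)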