Changes from 29 of 31 commits
620a042
feat: Add option to generate LM image and GC via two separate jobs
NeoLegends Aug 21, 2023
a68265c
chore: Document parameter
NeoLegends Aug 21, 2023
b74c654
fix: Always assign the (possibly to None) lm_gc property
NeoLegends Aug 21, 2023
79cbd2a
fix bug, assign jobs to class if possible
NeoLegends Aug 22, 2023
78f6bed
refactor find_arpa_lms into standalone function
NeoLegends Aug 29, 2023
291b734
Merge branch 'main' into feat/separate-lmi-gc-generation
NeoLegends Aug 29, 2023
d209e8f
fix bugs from trial runs
NeoLegends Aug 29, 2023
02d7088
Re-enable lm-util
NeoLegends Aug 29, 2023
6a236c8
more mem for LM + GC jobs
NeoLegends Aug 29, 2023
a61b03f
make mem configurable
NeoLegends Aug 29, 2023
008c9dc
even more mem
NeoLegends Aug 29, 2023
07e1136
enable split behavior by default, document hash implications
NeoLegends Aug 29, 2023
cdc791a
disable flag by default
NeoLegends Aug 30, 2023
ab010c5
Rename flag to be more clear
NeoLegends Sep 11, 2023
5a8ec3d
rename local function
NeoLegends Sep 11, 2023
ed001c4
fix wording
NeoLegends Nov 6, 2023
115b160
Merge branch 'main' into feat/separate-lmi-gc-generation
NeoLegends Nov 9, 2023
49315ea
Merge branch 'main' into feat/separate-lmi-gc-generation
Jul 15, 2025
0182ad6
introduce enum
Jul 30, 2025
f3199fb
Apply suggestions from code review
DanEnergetics Jul 30, 2025
f784e87
Merge branch 'feat/separate-lmi-gc-generation' of github.com:rwth-i6/…
Jul 30, 2025
f02c2c9
change util function signature
Jul 30, 2025
9b61538
more reviewer comments
Jul 30, 2025
d2507d0
ruff formatting
Jul 30, 2025
a5a6af9
adjust rescoring job
Jul 30, 2025
f74c52f
fix parameter name typo
Jul 30, 2025
0f82f3f
fix empty post config
Jul 30, 2025
ffe3681
postpone arpa discovery -> should fix hash test
Aug 6, 2025
ef5e8ce
ruff
Aug 6, 2025
a2292d0
reviewer comments
Aug 15, 2025
7d1bcc9
more reviewer comments
Aug 15, 2025
1 change: 1 addition & 0 deletions lm/__init__.py
@@ -3,3 +3,4 @@
from .reverse_arpa import *
from .vocabulary import *
from .srilm import *
from .util import *
2 changes: 1 addition & 1 deletion lm/lm_image.py
@@ -21,7 +21,7 @@ def __init__(
extra_config=None,
extra_post_config=None,
encoding="utf-8",
mem=2,
mem=12,
):
kwargs = locals()
del kwargs["self"]
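The default memory requirement for LM image creation rises from 2 GB to 12 GB here; per the commit history ("more mem for LM + GC jobs", "make mem configurable"), large ARPA LMs need the headroom, while callers can still request less. A minimal sketch, assuming this `__init__` belongs to the `CreateLmImageJob` that `advanced_tree_search.py` references below and that it takes the `crp` as its first argument:

```python
from i6_core import lm

# crp: rasr.CommonRasrParameters with language_model_config pointing at an
# ARPA LM (assumed prepared elsewhere). `mem` is in GB; the default is now 12.
image_job = lm.CreateLmImageJob(crp, mem=4)  # small LM -> smaller request
lm_image = image_job.out_image  # compiled image, later assigned to <lm>.image
```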
26 changes: 26 additions & 0 deletions lm/util.py
@@ -0,0 +1,26 @@
from typing import List, Optional, Tuple

import i6_core.rasr as rasr


def _has_image(config: rasr.RasrConfig, post_config: Optional[rasr.RasrConfig]) -> bool:
    res = config._get("image") is not None
    res = res or (post_config is not None and post_config._get("image") is not None)
    return res


def find_arpa_lms(
    lm_config: rasr.RasrConfig, lm_post_config: Optional[rasr.RasrConfig] = None
) -> List[Tuple[rasr.RasrConfig, Optional[rasr.RasrConfig]]]:
    """
    Recursively collect all ARPA (sub-)LM configs that do not have an image set yet,
    descending into "combine" LMs. Returns (config, post_config) pairs; the post
    config may be None.
    """
    result = []

    if lm_config.type == "ARPA":
        if not _has_image(lm_config, lm_post_config):
            result.append((lm_config, lm_post_config))
    elif lm_config.type == "combine":
        for i in range(1, lm_config.num_lms + 1):
            sub_lm_config = lm_config[f"lm-{i}"]
            sub_lm_post_config = lm_post_config[f"lm-{i}"] if lm_post_config is not None else None
            result += find_arpa_lms(sub_lm_config, sub_lm_post_config)

    return result
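For illustration, a short sketch of how a recipe might call the new helper. It is grounded in how `advanced_tree_search.py` uses it below; that ARPA LM configs expose a `file` attribute is an assumption for the print statement:

```python
from i6_core.lm import find_arpa_lms

# crp: rasr.CommonRasrParameters with a "combine" LM configured (assumed).
# find_arpa_lms returns (config, post_config) pairs for every ARPA (sub-)LM
# that does not have an image yet, e.g. one entry per needed image job:
arpa_lms = find_arpa_lms(crp.language_model_config)
for i, (lm_config, _) in enumerate(arpa_lms):
    print(f"lm-{i + 1} still needs an image:", lm_config.file)
```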
2 changes: 1 addition & 1 deletion rasr/crp.py
@@ -79,7 +79,7 @@ def set_executables(self, rasr_binary_path, rasr_arch="linux-x86_64-standard"):
self.flf_tool_exe = rasr_binary_path.join_right(f"flf-tool.{rasr_arch}")
self.kws_tool_exe = None # does not exist
self.lattice_processor_exe = rasr_binary_path.join_right(f"lattice-processor.{rasr_arch}")
self.lm_util_exe = None # does not exist
self.lm_util_exe = rasr_binary_path.join_right(f"lm-util.{rasr_arch}")
self.nn_trainer_exe = rasr_binary_path.join_right(f"nn-trainer.{rasr_arch}")
self.speech_recognizer_exe = rasr_binary_path.join_right(f"speech-recognizer.{rasr_arch}")

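`set_executables` now resolves `lm-util` like the other RASR binaries instead of pinning it to `None`, which the separate LM image job presumably relies on. A short sketch of the effect (the binary path is a placeholder):

```python
import i6_core.rasr as rasr
from sisyphus import tk

crp = rasr.CommonRasrParameters()
crp.set_executables(rasr_binary_path=tk.Path("/opt/rasr/arch/linux-x86_64-standard"))

# lm-util is now wired up like flf-tool, nn-trainer, etc.
assert crp.lm_util_exe is not None  # was previously always None ("does not exist")
```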
114 changes: 73 additions & 41 deletions recognition/advanced_tree_search.py
@@ -10,6 +10,8 @@

Path = setup_path(__package__)

import copy
from enum import Enum
import math
import os
import shutil
@@ -69,25 +71,6 @@ def run(self):
def cleanup_before_run(self, cmd, retry, *args):
util.backup_if_exists("lm_and_state_tree.log")

@classmethod
def find_arpa_lms(cls, lm_config, lm_post_config=None):
result = []

def has_image(c, pc):
res = c._get("image") is not None
res = res or (pc is not None and pc._get("image") is not None)
return res

if lm_config.type == "ARPA":
if not has_image(lm_config, lm_post_config):
result.append((lm_config, lm_post_config))
elif lm_config.type == "combine":
for i in range(1, lm_config.num_lms + 1):
sub_lm_config = lm_config["lm-%d" % i]
sub_lm_post_config = lm_post_config["lm-%d" % i] if lm_post_config is not None else None
result += cls.find_arpa_lms(sub_lm_config, sub_lm_post_config)
return result

@classmethod
def create_config(cls, crp, feature_scorer, extra_config, extra_post_config, **kwargs):
config, post_config = rasr.build_config_from_mapping(
@@ -117,7 +100,7 @@ def create_config(cls, crp, feature_scorer, extra_config, extra_post_config, **k
config.flf_lattice_tool.network.recognizer.feature_extraction.file = "dummy.flow"
config.flf_lattice_tool.network.recognizer.lm.scale = 1.0

arpa_lms = cls.find_arpa_lms(
arpa_lms = lm.find_arpa_lms(
config.flf_lattice_tool.network.recognizer.lm,
post_config.flf_lattice_tool.network.recognizer.lm if post_config is not None else None,
)
@@ -149,6 +132,11 @@ def hash(cls, kwargs):


class AdvancedTreeSearchJob(rasr.RasrCommand, Job):
class LmCacheMethod(Enum):
JOINED = "joined"
SEPARATE = "separate"
NONE = "none"

def __init__(
self,
crp: rasr.CommonRasrParameters,
@@ -167,6 +155,7 @@ def __init__(
lmgc_mem: float = 12.0,
lmgc_alias: Optional[str] = None,
lmgc_scorer: Optional[rasr.FeatureScorer] = None,
lm_cache_method: LmCacheMethod = LmCacheMethod.JOINED,
model_combination_config: Optional[rasr.RasrConfig] = None,
model_combination_post_config: Optional[rasr.RasrConfig] = None,
extra_config: Optional[rasr.RasrConfig] = None,
@@ -190,6 +179,10 @@ def __init__(
:param lmgc_mem: Memory requirement for the AdvancedTreeSearchLmImageAndGlobalCacheJob
:param lmgc_alias: Alias for the AdvancedTreeSearchLmImageAndGlobalCacheJob
:param lmgc_scorer: Dummy scorer for the AdvancedTreeSearchLmImageAndGlobalCacheJob which is required but unused
:param lm_cache_method: Specifies how the LM image and the global cache should be created:
JOINED (default) -> automatically create LM images and global cache as output of one job
SEPARATE -> automatically create LM images and global cache separately
NONE -> don't create LM images or global cache as part of this job at all

[Review comment, PR author] Nit: consider informing about the tradeoff between the different options here, i.e. that the default has hash issues.
:param model_combination_config: Configuration for model combination
:param model_combination_post_config: Post config for model combination
:param extra_config: Additional Config for recognition
@@ -206,7 +199,9 @@ def __init__(
self.config,
self.post_config,
self.lm_gc_job,
) = AdvancedTreeSearchJob.create_config(**kwargs)
self.gc_job,
self.lm_image_jobs,
) = self.create_config(**kwargs)
self.feature_flow = feature_flow
self.exe = self.select_exe(crp.flf_tool_exe, "flf-tool")
self.concurrent = crp.concurrent
@@ -286,18 +281,17 @@ def create_config(
lmgc_mem: float,
lmgc_alias: Optional[str],
lmgc_scorer: Optional[rasr.FeatureScorer],
lm_cache_method: LmCacheMethod,
model_combination_config: Optional[rasr.RasrConfig],
model_combination_post_config: Optional[rasr.RasrConfig],
extra_config: Optional[rasr.RasrConfig],
extra_post_config: Optional[rasr.RasrConfig],
**kwargs,
):
lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
)
if lmgc_alias is not None:
lm_gc.add_alias(lmgc_alias)
lm_gc.rqmt["mem"] = lmgc_mem
def add_lm_config_to_crp(crp: rasr.CommonRasrParameters, lm_config: rasr.RasrConfig):
crp = copy.deepcopy(crp)
crp.language_model_config = lm_config
return crp

search_parameters = cls.update_search_parameters(search_parameters)

@@ -396,15 +390,51 @@ def create_config(
"cache_high"
]

post_config.flf_lattice_tool.global_cache.read_only = True
post_config.flf_lattice_tool.global_cache.file = lm_gc.out_global_cache

arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms(
config.flf_lattice_tool.network.recognizer.lm,
post_config.flf_lattice_tool.network.recognizer.lm,
# Handle caching of ARPA LMs and maybe build global cache
arpa_lms = lm.find_arpa_lms(
config.flf_lattice_tool.network.recognizer.lm, post_config.flf_lattice_tool.network.recognizer.lm
)
for i, lm_config in enumerate(arpa_lms):
lm_config[1].image = lm_gc.out_lm_images[i + 1]
if lm_cache_method == cls.LmCacheMethod.NONE:
gc_job = None
lm_gc = None
lm_images = None
gc = None
lm_image_jobs = {}
elif lm_cache_method == cls.LmCacheMethod.SEPARATE:
gc_job = BuildGlobalCacheJob(crp, extra_config, extra_post_config)

lm_image_jobs = {
(i + 1): lm.CreateLmImageJob(
add_lm_config_to_crp(crp, lm_config), extra_config=extra_config, extra_post_config=extra_post_config
)
for i, (lm_config, _) in enumerate(arpa_lms)
}

gc = gc_job.out_global_cache
lm_images = {k: v.out_image for k, v in lm_image_jobs.items()}

lm_gc = None
elif lm_cache_method == cls.LmCacheMethod.JOINED:
lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
)
if lmgc_alias is not None:
lm_gc.add_alias(lmgc_alias)
lm_gc.rqmt["mem"] = lmgc_mem

gc = lm_gc.out_global_cache
lm_images = lm_gc.out_lm_images

gc_job = None
lm_image_jobs = {}
else:
raise TypeError("Argument `lm_cache_method` must be of type `AdvancedTreeSearchJob.LmCacheMethod`")

post_config.flf_lattice_tool.global_cache.read_only = True
if lm_cache_method != cls.LmCacheMethod.NONE:
post_config.flf_lattice_tool.global_cache.file = gc
for i, (_, lm_post_config) in enumerate(arpa_lms):
lm_post_config.image = lm_images[i + 1]

# Remaining Flf-network

@@ -438,11 +468,11 @@ def create_config(
config._update(extra_config)
post_config._update(extra_post_config)

return config, post_config, lm_gc
return config, post_config, lm_gc, gc_job, lm_image_jobs

@classmethod
def hash(cls, kwargs):
config, post_config, lm_gc = cls.create_config(**kwargs)
config, post_config, *jobs = cls.create_config(**kwargs)
return super().hash(
{
"config": config,
@@ -505,6 +535,8 @@ def __init__(
self.config,
self.post_config,
self.lm_gc_job,
self.gc_job,
self.lm_image_jobs,
) = AdvancedTreeSearchWithRescoringJob.create_config(**kwargs)

@classmethod
@@ -518,7 +550,7 @@ def create_config(
rescoring_lookahead_scale,
**kwargs,
):
config, post_config, lm_gc_job = super().create_config(**kwargs)
config, *remainder = super().create_config(**kwargs)

config.flf_lattice_tool.network.recognizer.links = "rescore"

Expand All @@ -533,7 +565,7 @@ def create_config(
rescore_config.lookahead_scale = rescoring_lookahead_scale
rescore_config.lm = rescoring_lm_config

return config, post_config, lm_gc_job
return config, *remainder
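The rescoring subclass now unpacks the parent's return value with a star pattern rather than naming every element, so it keeps working as `create_config` grows new return values. A tiny standalone illustration of the pattern (plain Python, independent of RASR):

```python
def parent_create_config():
    # stands in for AdvancedTreeSearchJob.create_config's growing return tuple
    return {"net": "recognizer"}, {"cache": "global"}, None, "gc_job", {}

def child_create_config():
    config, *remainder = parent_create_config()  # robust to extra elements
    config["links"] = "rescore"                  # subclass only touches config
    return config, *remainder

print(child_create_config())
```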


class BidirectionalAdvancedTreeSearchJob(rasr.RasrCommand, Job):
@@ -817,7 +849,7 @@ class BuildGlobalCacheJob(rasr.RasrCommand, Job):
Standalone job to create the global-cache for advanced-tree-search
"""

def __init__(self, crp, extra_config=None, extra_post_config=None):
def __init__(self, crp, extra_config=None, extra_post_config=None, mem=12):
"""
:param rasr.CommonRasrParameters crp: common RASR params (required: lexicon, acoustic_model, language_model, recognizer)
:param rasr.Configuration extra_config: overlay config that influences the Job's hash
@@ -837,7 +869,7 @@ def __init__(self, crp, extra_config=None, extra_post_config=None):
self.out_log_file = self.log_file_output_path("build_global_cache", crp, False)
self.out_global_cache = self.output_path("global.cache", cached=True)

self.rqmt = {"time": 1, "cpu": 1, "mem": 2}
self.rqmt = {"time": 1, "cpu": 1, "mem": mem}

def tasks(self):
yield Task("create_files", mini_task=True)
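Putting the pieces together: `lm_cache_method` selects which helper jobs `create_config` spawns, and the search job exposes them as attributes. A hedged usage sketch; the import path is assumed from the repository layout, `crp`, `feature_flow`, and `feature_scorer` stand in for objects an actual recipe already has, and any constructor arguments hidden in the collapsed parts of this diff are omitted:

```python
from i6_core.recognition.advanced_tree_search import AdvancedTreeSearchJob

search_job = AdvancedTreeSearchJob(
    crp=crp,                        # rasr.CommonRasrParameters (assumed)
    feature_flow=feature_flow,      # feature extraction flow network (assumed)
    feature_scorer=feature_scorer,  # rasr.FeatureScorer (assumed)
    lm_cache_method=AdvancedTreeSearchJob.LmCacheMethod.SEPARATE,
)

# SEPARATE: one standalone BuildGlobalCacheJob plus one CreateLmImageJob per
# uncached ARPA (sub-)LM, keyed from 1 to match RASR's lm-1, lm-2, ... naming.
global_cache_job = search_job.gc_job
image_jobs = search_job.lm_image_jobs  # e.g. {1: CreateLmImageJob, ...}
assert search_job.lm_gc_job is None    # only set for the default JOINED method

# JOINED (the default, which the review discussion notes keeps existing hashes
# but has its own hash issues) would instead set lm_gc_job and leave
# gc_job=None and lm_image_jobs={}.
```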