Skip to content

Commit 620a042

Browse files
committed
feat: Add option to generate LM image and GC via two separate jobs
Closes #430
1 parent f9a9f39 commit 620a042

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

recognition/advanced_tree_search.py

Lines changed: 34 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
Path = setup_path(__package__)
1212

13+
import copy
1314
import math
1415
import os
1516
import shutil
@@ -167,6 +168,7 @@ def __init__(
167168
lmgc_mem: float = 12.0,
168169
lmgc_alias: Optional[str] = None,
169170
lmgc_scorer: Optional[rasr.FeatureScorer] = None,
171+
separate_lmi_gc_generation: bool = False,
170172
model_combination_config: Optional[rasr.RasrConfig] = None,
171173
model_combination_post_config: Optional[rasr.RasrConfig] = None,
172174
extra_config: Optional[rasr.RasrConfig] = None,
@@ -286,18 +288,40 @@ def create_config(
286288
lmgc_mem: float,
287289
lmgc_alias: Optional[str],
288290
lmgc_scorer: Optional[rasr.FeatureScorer],
291+
separate_lmi_gc_generation: bool,
289292
model_combination_config: Optional[rasr.RasrConfig],
290293
model_combination_post_config: Optional[rasr.RasrConfig],
291294
extra_config: Optional[rasr.RasrConfig],
292295
extra_post_config: Optional[rasr.RasrConfig],
293296
**kwargs,
294297
):
295-
lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
296-
crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
297-
)
298-
if lmgc_alias is not None:
299-
lm_gc.add_alias(lmgc_alias)
300-
lm_gc.rqmt["mem"] = lmgc_mem
298+
def specialize_lm_config(crp, lm_config):
299+
crp = copy.deepcopy(crp)
300+
crp.language_model = lm_config
301+
return crp
302+
303+
if separate_lmi_gc_generation:
304+
gc = BuildGlobalCacheJob(crp, extra_config, extra_post_config).out_global_cache
305+
306+
arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms(
307+
crp.language_model, post_config.lm if post_config is not None else None
308+
)
309+
lm_images = {
310+
(i + 1): lm.CreateLmImageJob(
311+
specialize_lm_config(crp, lm), extra_config=extra_config, extra_post_config=extra_post_config
312+
).out_lm
313+
for i, lm in enumerate(arpa_lms)
314+
}
315+
else:
316+
lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
317+
crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
318+
)
319+
if lmgc_alias is not None:
320+
lm_gc.add_alias(lmgc_alias)
321+
lm_gc.rqmt["mem"] = lmgc_mem
322+
323+
gc = lm_gc.out_global_cache
324+
lm_images = lm_gc.out_lm_images
301325

302326
search_parameters = cls.update_search_parameters(search_parameters)
303327

@@ -397,14 +421,14 @@ def create_config(
397421
]
398422

399423
post_config.flf_lattice_tool.global_cache.read_only = True
400-
post_config.flf_lattice_tool.global_cache.file = lm_gc.out_global_cache
424+
post_config.flf_lattice_tool.global_cache.file = gc
401425

402426
arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms(
403427
config.flf_lattice_tool.network.recognizer.lm,
404428
post_config.flf_lattice_tool.network.recognizer.lm,
405429
)
406430
for i, lm_config in enumerate(arpa_lms):
407-
lm_config[1].image = lm_gc.out_lm_images[i + 1]
431+
lm_config[1].image = lm_images[i + 1]
408432

409433
# Remaining Flf-network
410434

@@ -438,11 +462,11 @@ def create_config(
438462
config._update(extra_config)
439463
post_config._update(extra_post_config)
440464

441-
return config, post_config, lm_gc
465+
return config, post_config
442466

443467
@classmethod
444468
def hash(cls, kwargs):
445-
config, post_config, lm_gc = cls.create_config(**kwargs)
469+
config, post_config = cls.create_config(**kwargs)
446470
return super().hash(
447471
{
448472
"config": config,

0 commit comments

Comments
 (0)