|
10 | 10 |
|
11 | 11 | Path = setup_path(__package__) |
12 | 12 |
|
| 13 | +import copy |
13 | 14 | import math |
14 | 15 | import os |
15 | 16 | import shutil |
@@ -167,6 +168,7 @@ def __init__( |
167 | 168 | lmgc_mem: float = 12.0, |
168 | 169 | lmgc_alias: Optional[str] = None, |
169 | 170 | lmgc_scorer: Optional[rasr.FeatureScorer] = None, |
| 171 | + separate_lmi_gc_generation: bool = False, |
170 | 172 | model_combination_config: Optional[rasr.RasrConfig] = None, |
171 | 173 | model_combination_post_config: Optional[rasr.RasrConfig] = None, |
172 | 174 | extra_config: Optional[rasr.RasrConfig] = None, |
@@ -286,18 +288,40 @@ def create_config( |
286 | 288 | lmgc_mem: float, |
287 | 289 | lmgc_alias: Optional[str], |
288 | 290 | lmgc_scorer: Optional[rasr.FeatureScorer], |
| 291 | + separate_lmi_gc_generation: bool, |
289 | 292 | model_combination_config: Optional[rasr.RasrConfig], |
290 | 293 | model_combination_post_config: Optional[rasr.RasrConfig], |
291 | 294 | extra_config: Optional[rasr.RasrConfig], |
292 | 295 | extra_post_config: Optional[rasr.RasrConfig], |
293 | 296 | **kwargs, |
294 | 297 | ): |
295 | | - lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob( |
296 | | - crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config |
297 | | - ) |
298 | | - if lmgc_alias is not None: |
299 | | - lm_gc.add_alias(lmgc_alias) |
300 | | - lm_gc.rqmt["mem"] = lmgc_mem |
| 298 | + def specialize_lm_config(crp, lm_config): |
| 299 | + crp = copy.deepcopy(crp) |
| 300 | + crp.language_model = lm_config |
| 301 | + return crp |
| 302 | + |
| 303 | + if separate_lmi_gc_generation: |
| 304 | + gc = BuildGlobalCacheJob(crp, extra_config, extra_post_config).out_global_cache |
| 305 | + |
| 306 | + arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms( |
| 307 | + crp.language_model, post_config.lm if post_config is not None else None |
| 308 | + ) |
| 309 | + lm_images = { |
| 310 | + (i + 1): lm.CreateLmImageJob( |
| 311 | + specialize_lm_config(crp, lm), extra_config=extra_config, extra_post_config=extra_post_config |
| 312 | + ).out_lm |
| 313 | + for i, lm in enumerate(arpa_lms) |
| 314 | + } |
| 315 | + else: |
| 316 | + lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob( |
| 317 | + crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config |
| 318 | + ) |
| 319 | + if lmgc_alias is not None: |
| 320 | + lm_gc.add_alias(lmgc_alias) |
| 321 | + lm_gc.rqmt["mem"] = lmgc_mem |
| 322 | + |
| 323 | + gc = lm_gc.out_global_cache |
| 324 | + lm_images = lm_gc.out_lm_images |
301 | 325 |
|
302 | 326 | search_parameters = cls.update_search_parameters(search_parameters) |
303 | 327 |
|
@@ -397,14 +421,14 @@ def create_config( |
397 | 421 | ] |
398 | 422 |
|
399 | 423 | post_config.flf_lattice_tool.global_cache.read_only = True |
400 | | - post_config.flf_lattice_tool.global_cache.file = lm_gc.out_global_cache |
| 424 | + post_config.flf_lattice_tool.global_cache.file = gc |
401 | 425 |
|
402 | 426 | arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms( |
403 | 427 | config.flf_lattice_tool.network.recognizer.lm, |
404 | 428 | post_config.flf_lattice_tool.network.recognizer.lm, |
405 | 429 | ) |
406 | 430 | for i, lm_config in enumerate(arpa_lms): |
407 | | - lm_config[1].image = lm_gc.out_lm_images[i + 1] |
| 431 | + lm_config[1].image = lm_images[i + 1] |
408 | 432 |
|
409 | 433 | # Remaining Flf-network |
410 | 434 |
|
@@ -438,11 +462,11 @@ def create_config( |
438 | 462 | config._update(extra_config) |
439 | 463 | post_config._update(extra_post_config) |
440 | 464 |
|
441 | | - return config, post_config, lm_gc |
| 465 | + return config, post_config |
442 | 466 |
|
443 | 467 | @classmethod |
444 | 468 | def hash(cls, kwargs): |
445 | | - config, post_config, lm_gc = cls.create_config(**kwargs) |
| 469 | + config, post_config = cls.create_config(**kwargs) |
446 | 470 | return super().hash( |
447 | 471 | { |
448 | 472 | "config": config, |
|
0 commit comments