|
29 | 29 |
|
30 | 30 | from torchtune.training import set_default_dtype |
31 | 31 |
|
32 | | -from torchchat.model import Model, ModelType |
| 32 | +from torchchat.model import Model, ModelArgs, ModelType |
33 | 33 |
|
34 | 34 | from torchchat.model_config.model_config import resolve_model_config |
35 | 35 | from torchchat.utils.build_utils import ( |
@@ -527,18 +527,22 @@ def _initialize_model( |
527 | 527 | ) |
528 | 528 | builder_args.device = "cpu" |
529 | 529 |
|
530 | | - # assert ( |
531 | | - # quantize is None or quantize == "{ }" |
532 | | - # ), "quantize not valid for exported PTE model. Specify quantization during export." |
533 | | - |
534 | | - with measure_time("Time to load model: {time:.02f} seconds"): |
535 | | - model = _load_model(builder_args) |
536 | | - device_sync(device=builder_args.device) |
| 530 | + # Resolve ModelArgs for constructing the PTEModel |
| 531 | + # If a manual params_path is provided, use that |
| 532 | + if builder_args.params_path: |
| 533 | + config: ModelArgs = ModelArgs.from_params(builder_args.params_path) |
| 534 | + else: |
| 535 | + # TODO: Instead of loading the whole model, refactor to call a |
| 536 | + # helper that generates just model.config
| 537 | + with measure_time("Time to load model: {time:.02f} seconds"): |
| 538 | + model = _load_model(builder_args) |
| 539 | + device_sync(device=builder_args.device) |
| 540 | + config = model.config |
537 | 541 |
|
538 | 542 | try: |
539 | 543 | from torchchat.model import PTEModel |
540 | 544 |
|
541 | | - model = PTEModel(model.config, builder_args.pte_path) |
| 545 | + model = PTEModel(config, builder_args.pte_path) |
542 | 546 | except Exception: |
543 | 547 | raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") |
544 | 548 | else: |
|
0 commit comments