|  | 
| 29 | 29 | 
 | 
| 30 | 30 | from torchtune.training import set_default_dtype | 
| 31 | 31 | 
 | 
| 32 |  | -from torchchat.model import Model, ModelType | 
|  | 32 | +from torchchat.model import Model, ModelArgs, ModelType | 
| 33 | 33 | 
 | 
| 34 | 34 | from torchchat.model_config.model_config import resolve_model_config | 
| 35 | 35 | from torchchat.utils.build_utils import ( | 
| @@ -527,18 +527,22 @@ def _initialize_model( | 
| 527 | 527 |             ) | 
| 528 | 528 |             builder_args.device = "cpu" | 
| 529 | 529 | 
 | 
| 530 |  | -        # assert ( | 
| 531 |  | -        #     quantize is None or quantize == "{ }" | 
| 532 |  | -        # ), "quantize not valid for exported PTE model. Specify quantization during export." | 
| 533 |  | - | 
| 534 |  | -        with measure_time("Time to load model: {time:.02f} seconds"): | 
| 535 |  | -            model = _load_model(builder_args) | 
| 536 |  | -            device_sync(device=builder_args.device) | 
|  | 530 | +        # Resolve ModelArgs for constructing the PTEModel | 
|  | 531 | +        # If a manual params_path is provided, use that | 
|  | 532 | +        if builder_args.params_path: | 
|  | 533 | +            config: ModelArgs = ModelArgs.from_params(builder_args.params_path) | 
|  | 534 | +        else: | 
|  | 535 | +            # TODO: Instead of loading the whole model, refactor to call a | 
|  | 536 | +            # helper that generate just model.config | 
|  | 537 | +            with measure_time("Time to load model: {time:.02f} seconds"): | 
|  | 538 | +                model = _load_model(builder_args) | 
|  | 539 | +                device_sync(device=builder_args.device) | 
|  | 540 | +                config = model.config | 
| 537 | 541 | 
 | 
| 538 | 542 |         try: | 
| 539 | 543 |             from torchchat.model import PTEModel | 
| 540 | 544 | 
 | 
| 541 |  | -            model = PTEModel(model.config, builder_args.pte_path) | 
|  | 545 | +            model = PTEModel(config, builder_args.pte_path) | 
| 542 | 546 |         except Exception: | 
| 543 | 547 |             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") | 
| 544 | 548 |     else: | 
|  | 
0 commit comments