diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 5dbf48529..b7eb32000 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -29,7 +29,7 @@
 
 from torchtune.training import set_default_dtype
 
-from torchchat.model import Model, ModelType
+from torchchat.model import Model, ModelArgs, ModelType
 
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
@@ -527,18 +527,22 @@ def _initialize_model(
             )
             builder_args.device = "cpu"
 
-        # assert (
-        #     quantize is None or quantize == "{ }"
-        # ), "quantize not valid for exported PTE model. Specify quantization during export."
-
-        with measure_time("Time to load model: {time:.02f} seconds"):
-            model = _load_model(builder_args)
-            device_sync(device=builder_args.device)
+        # Resolve ModelArgs for constructing the PTEModel
+        # If a manual params_path is provided, use that
+        if builder_args.params_path:
+            config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
+        else:
+            # TODO: Instead of loading the whole model, refactor to call a
+            # helper that generates just model.config
+            with measure_time("Time to load model: {time:.02f} seconds"):
+                model = _load_model(builder_args)
+                device_sync(device=builder_args.device)
+            config = model.config
 
         try:
             from torchchat.model import PTEModel
 
-            model = PTEModel(config, builder_args.pte_path)
+            model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
     else:
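
For reviewers, a minimal standalone sketch of the config-resolution flow this patch introduces. The helper name `resolve_pte_config` is hypothetical; `ModelArgs.from_params`, `_load_model`, `PTEModel`, and the `BuilderArgs` fields are taken from the diff above.

# Hypothetical sketch of the new logic in _initialize_model, simplified
# for illustration; not the literal builder.py code.
from torchchat.cli.builder import _load_model
from torchchat.model import ModelArgs, PTEModel

def resolve_pte_config(builder_args) -> ModelArgs:
    """Resolve the ModelArgs needed to construct a PTEModel."""
    if builder_args.params_path:
        # Fast path: build the config directly from a params JSON file,
        # skipping the (potentially expensive) eager checkpoint load.
        return ModelArgs.from_params(builder_args.params_path)
    # Fallback: load the whole eager model just to read model.config
    # (the TODO in the patch proposes a config-only helper instead).
    model = _load_model(builder_args)
    return model.config

# Usage, mirroring the patched code path:
#   config = resolve_pte_config(builder_args)
#   model = PTEModel(config, builder_args.pte_path)

The design point is that a PTE (ExecuTorch) artifact only needs the model's config to be wrapped in a `PTEModel`, so when a params file is supplied the full eager-model load can be skipped entirely.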