Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit b037b71

Browse files
authored
Improve pte loading when given manual params_path (#1178)
1 parent 72d2d20 commit b037b71

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

torchchat/cli/builder.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929

3030
from torchtune.training import set_default_dtype
3131

32-
from torchchat.model import Model, ModelType
32+
from torchchat.model import Model, ModelArgs, ModelType
3333

3434
from torchchat.model_config.model_config import resolve_model_config
3535
from torchchat.utils.build_utils import (
@@ -527,18 +527,22 @@ def _initialize_model(
527527
)
528528
builder_args.device = "cpu"
529529

530-
# assert (
531-
# quantize is None or quantize == "{ }"
532-
# ), "quantize not valid for exported PTE model. Specify quantization during export."
533-
534-
with measure_time("Time to load model: {time:.02f} seconds"):
535-
model = _load_model(builder_args)
536-
device_sync(device=builder_args.device)
530+
# Resolve ModelArgs for constructing the PTEModel
531+
# If a manual params_path is provided, use that
532+
if builder_args.params_path:
533+
config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
534+
else:
535+
# TODO: Instead of loading the whole model, refactor to call a
536+
# helper that generates just model.config
537+
with measure_time("Time to load model: {time:.02f} seconds"):
538+
model = _load_model(builder_args)
539+
device_sync(device=builder_args.device)
540+
config = model.config
537541

538542
try:
539543
from torchchat.model import PTEModel
540544

541-
model = PTEModel(model.config, builder_args.pte_path)
545+
model = PTEModel(config, builder_args.pte_path)
542546
except Exception:
543547
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
544548
else:

0 commit comments

Comments (0)