diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 02b1545d0..bcb737202 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -79,19 +79,16 @@ def __post_init__(self): if self.dso_path and self.pte_path: raise RuntimeError("specify either DSO path or PTE path, but not both") - if self.checkpoint_path and (self.dso_path or self.pte_path): - print( - "Warning: checkpoint path ignored because an exported DSO or PTE path specified" - ) - if self.checkpoint_dir and (self.dso_path or self.pte_path): - print( - "Warning: checkpoint dir ignored because an exported DSO or PTE path specified" - ) - if self.gguf_path and (self.dso_path or self.pte_path): - print( - "Warning: GGUF path ignored because an exported DSO or PTE path specified" - ) - if not (self.dso_path) and not (self.pte_path): + if self.dso_path or self.pte_path: + ignored_params = [ + (self.checkpoint_path, "checkpoint path"), + (self.checkpoint_dir, "checkpoint dir"), + (self.gguf_path, "GGUF path"), + ] + for param, param_msg in ignored_params: + if param: + print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified") + else: self.prefill_possible = True @classmethod @@ -446,7 +443,7 @@ def _maybe_init_distributed( return world_mesh, parallel_dims -def _maybe_parellelize_model( +def _maybe_parallelize_model( model: nn.Module, builder_args: BuilderArgs, world_mesh: DeviceMesh, @@ -486,7 +483,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: model = _init_model_on_meta_device(builder_args) else: model = _load_model_default(builder_args) - model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims) + model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval()