diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 511cf1f35..504669563 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -66,6 +66,8 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
 
+    state_dict_path: Optional[Union[Path, str]] = None
+
     def __post_init__(self):
         if self.device is None:
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -185,6 +187,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            state_dict_path=args.state_dict_path,
         )
 
     @classmethod
@@ -579,26 +582,47 @@ def _initialize_model(
         model = _load_model(builder_args)
         device_sync(device=builder_args.device)
 
-        if quantize:
-            print(f"Quantizing the model with: {quantize}")
-            with measure_time("Time to quantize model: {time:.02f} seconds"):
-                quantize_model(
-                    model,
-                    builder_args.device,
-                    quantize,
-                    tokenizer,
-                    support_tensor_subclass,
-                )
-                device_sync(device=builder_args.device)
+        state_dict_path = builder_args.state_dict_path
+        state_dict_exists: bool = state_dict_path is not None and os.path.isfile(state_dict_path)
+        if quantize or state_dict_exists:
 
-        if builder_args.setup_caches:
-            with torch.device(builder_args.device):
-                model.setup_caches(
-                    max_batch_size=1,
-                    max_seq_length=max_seq_length
-                    or model.text_transformer_args.max_seq_length,
+            if quantize and state_dict_exists:
+                print(
+                    "WARNING: Both a state dict path and a quantize arg were provided; ignoring the quantize arg"
                 )
+            if state_dict_exists:
+                with measure_time("Time to load quantized state: {time:.02f} seconds"):
+                    print(f"Loading the model state from: {state_dict_path}")
+                    model.load_state_dict(torch.load(state_dict_path))
+                    device_sync(device=builder_args.device)
+            else:
+                with measure_time("Time to quantize model: {time:.02f} seconds"):
+                    print(f"Quantizing the model with: {quantize}")
+                    quantize_model(
+                        model,
+                        builder_args.device,
+                        quantize,
+                        tokenizer,
+                        support_tensor_subclass,
+                    )
+                    device_sync(device=builder_args.device)
+
+                if state_dict_path:
+                    with measure_time(
+                        "Time to save quantized state: {time:.02f} seconds"
+                    ):
+                        print(f"Saving the quantized state dict to: {state_dict_path}")
+                        torch.save(model.state_dict(), state_dict_path)
+
+        if builder_args.setup_caches:
+            with torch.device(builder_args.device):
+                model.setup_caches(
+                    max_batch_size=1,
+                    max_seq_length=max_seq_length
+                    or model.text_transformer_args.max_seq_length,
+                )
+
 
         model.to(dtype=builder_args.precision)
 
         print("-----------------------------------------------------------")
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index bc41d56ec..186737434 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -148,6 +148,12 @@ def _add_model_config_args(parser, verb: str) -> None:
         help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
     )
 
+    model_config_parser.add_argument(
+        "--state-dict-path",
+        type=str,
+        default=None,
+        help="Model state dict to load (if the path exists) or write out to (if it doesn't). Supersedes the --quantize arg when the file already exists.",
+    )
     model_config_parser.add_argument(
         "--dtype",
         default="fast",
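
Usage note (not part of the diff): with this change, a first run that passes both --quantize and --state-dict-path quantizes the model and writes the resulting state dict to the given path; later runs that pass the same --state-dict-path load that file and skip quantization. The sketch below illustrates the same save/load round trip in isolation; the tiny nn.Linear model and the /tmp path are illustrative only, not torchchat code.

import os

import torch
import torch.nn as nn

state_dict_path = "/tmp/toy_state_dict.pt"  # hypothetical cache location
model = nn.Linear(8, 8)

if os.path.isfile(state_dict_path):
    # Later runs: load the cached state dict instead of redoing the expensive step.
    model.load_state_dict(torch.load(state_dict_path))
else:
    # First run: do the expensive transformation (quantization in this PR),
    # then persist the resulting state dict for reuse.
    torch.save(model.state_dict(), state_dict_path)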