diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 2933edade..4c3683e73 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -36,6 +36,8 @@
 
 from torchchat.model import Model, ModelType
 
+from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
+
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -387,9 +389,23 @@ def _load_model_default(builder_args, only_config=False):
         with set_default_dtype(builder_args.precision), torch.device(
             builder_args.device
         ):
-            model = Model.from_params(builder_args.params_path)
+            # Recreating the model here would double its size in memory, with redundant copies of the initialized weights.
+            # model = Model.from_params(builder_args.params_path)
+
+            # Buffers in the rotary embedding are not included in the checkpoint;
+            # instead, they are computed at initialization. Since buffers on the
+            # meta device do not hold any actual values, they must be reinitialized
+            # on the actual device. Only reinitialize those buffers, without
+            # reinitializing the entire model.
+            decoder_config = model.config.transformer_args['decoder']
+            head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
+            max_seq_len = decoder_config['max_seq_len']
+            rope_base = decoder_config['rope_base']
+            for submodule in model.modules():
+                if isinstance(submodule, RotaryPositionalEmbeddings):
+                    submodule.__init__(head_dim, max_seq_len, rope_base)
         state_dict = flamingo_meta_to_tune(checkpoint)
-        model.model.load_state_dict(state_dict)
+        model.model.load_state_dict(state_dict, assign=True, strict=False)
     else:
         checkpoint = {"model." + k: v for k, v in checkpoint.items()}
         model.load_state_dict(checkpoint, assign=True, strict=True)
@@ -472,7 +488,6 @@ def _load_model(builder_args, only_config=False):
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
-
 def _initialize_model(
     builder_args,
     quantize,
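
For context on why the patch both re-runs `RotaryPositionalEmbeddings.__init__` and loads with `assign=True, strict=False`, here is a minimal, self-contained sketch of the meta-device loading pattern it relies on. This is illustration only, not part of the patch: `ToyRope` and `ToyModel` are hypothetical stand-ins for torchtune's `RotaryPositionalEmbeddings` and the torchchat `Model`.

```python
import torch
import torch.nn as nn


class ToyRope(nn.Module):
    """Stand-in for RotaryPositionalEmbeddings: its cache is a buffer that is
    computed in __init__ and (being non-persistent) never checkpointed."""

    def __init__(self, dim: int, max_seq_len: int = 16, base: int = 10_000):
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        # persistent=False mirrors torchtune: the cache is excluded from state_dict()
        self.register_buffer("cache", positions[:, None] * theta[None, :], persistent=False)


class ToyModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.rope = ToyRope(dim)


# 1. Build the skeleton on the meta device: no real storage is allocated.
with torch.device("meta"):
    model = ToyModel()
assert model.rope.cache.is_meta

# 2. Load a real checkpoint with assign=True: parameters are replaced by the
#    checkpoint's real tensors, but the rope cache is absent from the
#    checkpoint, so it stays on meta.
checkpoint = ToyModel().state_dict()  # pretend this came from disk
model.load_state_dict(checkpoint, assign=True, strict=False)
assert not model.proj.weight.is_meta
assert model.rope.cache.is_meta

# 3. Re-run __init__ on just the rope modules, on the target device, to
#    recompute the buffer -- the same trick as the loop in the diff above.
with torch.device("cpu"):
    for submodule in model.modules():
        if isinstance(submodule, ToyRope):
            submodule.__init__(8)  # dim must match the checkpointed model
assert not model.rope.cache.is_meta
```

Re-running `__init__` on the matched modules, as the patch does, recomputes the buffer values on the current default device; materializing storage alone (e.g. via `Module.to_empty()`) would leave the rotary cache filled with uninitialized memory.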