 
 from torchchat.model import Model, ModelType
 
+from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
+
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -387,9 +389,23 @@ def _load_model_default(builder_args, only_config=False):
         with set_default_dtype(builder_args.precision), torch.device(
             builder_args.device
         ):
-            model = Model.from_params(builder_args.params_path)
+            # Initializing the model here would double peak memory, since the freshly initialized weights are redundant with the checkpoint loaded below.
+            # model = Model.from_params(builder_args.params_path)
+
+            # Buffers in the rotary embedding are not included in the checkpoint;
+            # instead, they are computed at initialization time. Since buffers on the
+            # meta device do not hold any actual values, they need to be reinitialized
+            # on the target device. Only these buffers are reinitialized here, without
+            # initializing the entire model.
+            decoder_config = model.config.transformer_args['decoder']
+            head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
+            max_seq_len = decoder_config['max_seq_len']
+            rope_base = decoder_config['rope_base']
+            for submodule in model.modules():
+                if isinstance(submodule, RotaryPositionalEmbeddings):
+                    submodule.__init__(head_dim, max_seq_len, rope_base)
         state_dict = flamingo_meta_to_tune(checkpoint)
-        model.model.load_state_dict(state_dict)
+        model.model.load_state_dict(state_dict, assign=True, strict=False)
     else:
         checkpoint = {"model." + k: v for k, v in checkpoint.items()}
         model.load_state_dict(checkpoint, assign=True, strict=True)
@@ -472,7 +488,6 @@ def _load_model(builder_args, only_config=False):
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
-
 def _initialize_model(
     builder_args,
     quantize,
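
For reference, here is a minimal, self-contained sketch of the loading pattern this diff adopts: build the model on the meta device, re-run `__init__` on the rotary-embedding submodules so their non-persistent buffers get real values on the target device, and then load the checkpoint with `assign=True`. `ToyRoPE` and `ToyModel` are hypothetical stand-ins for illustration, not torchchat or torchtune classes; only the loading pattern mirrors the change above.

```python
import torch
import torch.nn as nn


class ToyRoPE(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 128, base: int = 10_000):
        super().__init__()
        self.dim, self.max_seq_len, self.base = dim, max_seq_len, base
        # Non-persistent buffer computed at init time: it is absent from
        # checkpoints, so it stays a valueless meta tensor when the model is
        # built on the meta device.
        theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("theta", theta, persistent=False)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rope = ToyRoPE(dim=8)
        self.proj = nn.Linear(8, 8)


# 1) Build on the meta device: no real storage is allocated, so constructing
#    the model does not add to peak memory on top of the checkpoint.
with torch.device("meta"):
    model = ToyModel()

# 2) Re-run __init__ only for the RoPE submodules on the target device, because
#    their buffers are not in the checkpoint and cannot be materialized from it.
device = "cpu"
with torch.device(device):
    for submodule in model.modules():
        if isinstance(submodule, ToyRoPE):
            submodule.__init__(submodule.dim, submodule.max_seq_len, submodule.base)

# 3) Load the checkpoint with assign=True (PyTorch >= 2.1) so the loaded tensors
#    replace the meta parameters outright instead of being copied into them;
#    strict=False tolerates the buffers that were rebuilt above and are missing
#    from the checkpoint.
checkpoint = ToyModel().state_dict()  # stand-in for the weights read from disk
model.load_state_dict(checkpoint, assign=True, strict=False)
assert not any(p.is_meta for p in model.parameters())
```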