
Commit eb68319

Author: Martin Yuan

[Flamingo] Fix 2x-checkpoint-size memory usage after loading

Parent: e27e162
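In short, the fix avoids materializing the Flamingo model twice: instead of building real weights with Model.from_params and then copying the checkpoint into them, the model stays on the meta device and load_state_dict(..., assign=True) adopts the checkpoint tensors directly. Below is a minimal sketch of the before/after memory behavior, using a made-up ToyModel rather than the torchchat model (assign=True requires PyTorch 2.1 or newer):

import torch
import torch.nn as nn

# Hypothetical stand-in for the Flamingo model; not torchchat code.
class ToyModel(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

checkpoint = {"proj.weight": torch.randn(1024, 1024)}  # stands in for the loaded checkpoint

# Before the fix: the constructor allocates and randomly initializes real weights,
# then load_state_dict() copies the checkpoint into them, so the checkpoint tensors
# and the freshly initialized storage coexist (~2x the checkpoint size).
model = ToyModel()
model.load_state_dict(checkpoint)

# After the fix: construct on the meta device (no storage is allocated), then let
# load_state_dict(assign=True) adopt the checkpoint tensors as the parameters.
with torch.device("meta"):
    model = ToyModel()
model.load_state_dict(checkpoint, assign=True)

With the meta-device path, each weight tensor is allocated exactly once, by whatever loaded the checkpoint; the model never owns a second copy.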

File tree: 1 file changed (+18 −3 lines)


torchchat/cli/builder.py

Lines changed: 18 additions & 3 deletions
@@ -36,6 +36,8 @@
 
 from torchchat.model import Model, ModelType
 
+from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
+
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -387,9 +389,23 @@ def _load_model_default(builder_args, only_config=False):
         with set_default_dtype(builder_args.precision), torch.device(
             builder_args.device
         ):
-            model = Model.from_params(builder_args.params_path)
+            # This doubles the model size in memory, with redundant copies of the initialized weights.
+            # model = Model.from_params(builder_args.params_path)
+
+            # Buffers in the rotary embedding are not included in the checkpoint.
+            # Instead, they are calculated during initialization. Since buffers on the
+            # meta device do not hold any actual values, they need to be reinitialized
+            # on the actual device. Only those buffers are initialized, without
+            # initializing the entire model.
+            decoder_config = model.config.transformer_args['decoder']
+            head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
+            max_seq_len = decoder_config['max_seq_len']
+            rope_base = decoder_config['rope_base']
+            for submodule in model.modules():
+                if isinstance(submodule, RotaryPositionalEmbeddings):
+                    submodule.__init__(head_dim, max_seq_len, rope_base)
         state_dict = flamingo_meta_to_tune(checkpoint)
-        model.model.load_state_dict(state_dict)
+        model.model.load_state_dict(state_dict, assign=True, strict=False)
     else:
         checkpoint = {"model." + k: v for k, v in checkpoint.items()}
         model.load_state_dict(checkpoint, assign=True, strict=True)
@@ -472,7 +488,6 @@ def _load_model(builder_args, only_config=False):
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
-
 def _initialize_model(
     builder_args,
     quantize,
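One wrinkle the diff handles: RotaryPositionalEmbeddings computes its rope buffers in __init__ rather than loading them from the checkpoint, so on a meta-device model those buffers never receive real values. The commit therefore re-runs __init__ on just those submodules under the real device, and loads non-strictly so that keys missing from the converted checkpoint do not abort the load. A minimal sketch of the same pattern, with a hypothetical DerivedBufferBlock standing in for the rotary embedding (not torchtune code):

import torch
import torch.nn as nn

# Hypothetical submodule whose buffer is derived in __init__ and is not part of
# the checkpoint -- playing the role of RotaryPositionalEmbeddings here.
class DerivedBufferBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.register_buffer("table", torch.arange(dim, dtype=torch.float32))

class ToyModel(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.block = DerivedBufferBlock(dim)

with torch.device("meta"):
    model = ToyModel()  # nothing is materialized yet; buffers hold no values

# Re-run __init__ only on the buffer-deriving submodules, under a real device,
# so their buffers get actual values without materializing the whole model.
with torch.device("cpu"):
    for submodule in model.modules():
        if isinstance(submodule, DerivedBufferBlock):
            submodule.__init__(submodule.dim)

# The checkpoint has no entry for 'block.table', hence strict=False; assign=True
# makes the parameters take over the checkpoint tensors instead of copying them.
checkpoint = {"proj.weight": torch.randn(8, 8)}
model.load_state_dict(checkpoint, assign=True, strict=False)

Reinitializing only the buffer-deriving submodules matters: calling __init__ on the whole model would reallocate random weights on the real device and reintroduce the 2x memory cost the commit is removing.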
