This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit f0a03a7

Authored by iseeyuan and Martin Yuan

[Flamingo] Fix the memory usage of 2x checkpoint size after loading (#1201)

Co-authored-by: Martin Yuan <[email protected]>

1 parent 6fd90bc, commit f0a03a7
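The commit addresses the peak memory that comes from first materializing a fully initialized model and then loading a checkpoint on top of it. A minimal, hypothetical sketch of the underlying PyTorch pattern (not the torchchat code itself; it assumes PyTorch 2.1+ for the `assign` argument and uses a placeholder `nn.Linear` module) looks like this:

```python
import torch
import torch.nn as nn

# Minimal sketch of the pattern the commit relies on (hypothetical module, not
# the torchchat model). Parameters created under the meta device occupy no
# memory; load_state_dict(..., assign=True) then adopts the checkpoint tensors
# directly instead of copying them into freshly initialized weights, so peak
# memory stays near 1x the checkpoint size rather than 2x.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

checkpoint = {
    "weight": torch.randn(4096, 4096),
    "bias": torch.randn(4096),
}

model.load_state_dict(checkpoint, assign=True)
print(model.weight.device)  # cpu: the checkpoint tensors were adopted as-is
```

The catch, which the diff below works around, is that buffers computed in `__init__` (such as rotary-embedding caches) are not part of the checkpoint, so after an `assign=True` load with `strict=False` they would remain meta tensors unless they are re-created on the real device.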

File tree: 1 file changed (+18, -3 lines)

torchchat/cli/builder.py

Lines changed: 18 additions & 3 deletions
```diff
@@ -31,6 +31,8 @@
 
 from torchchat.model import Model, ModelArgs, ModelType
 
+from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
+
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -387,9 +389,23 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         with set_default_dtype(builder_args.precision), torch.device(
             builder_args.device
         ):
-            model = Model.from_params(builder_args.params_path)
+            # It doubles the model size the memory, with redundancies of the initialized weights.
+            # model = Model.from_params(builder_args.params_path)
+
+            # Buffers in rotary embedding are not included in the checkpoint.
+            # Instead, they are calculated in initialization. Since buffers on meta device
+            # does not host any actual values, need to reinitialize them in the actual
+            # device. Only do those buffer initialization, without initializing the entire
+            # model.
+            decoder_config = model.config.transformer_args['decoder']
+            head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
+            max_seq_len = decoder_config['max_seq_len']
+            rope_base = decoder_config['rope_base']
+            for submodule in model.modules():
+                if isinstance(submodule, RotaryPositionalEmbeddings):
+                    submodule.__init__(head_dim, max_seq_len, rope_base)
         state_dict = flamingo_meta_to_tune(checkpoint)
-        model.model.load_state_dict(state_dict)
+        model.model.load_state_dict(state_dict, assign=True, strict=False)
     else:
         checkpoint = {"model." + k: v for k, v in checkpoint.items()}
         model.load_state_dict(checkpoint, assign=True, strict=True)
@@ -472,7 +488,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
-
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
```
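Because the Flamingo load now passes `strict=False`, a tensor that is neither in the checkpoint nor reinitialized would silently stay on the meta device. A small, hypothetical sanity check (not part of the commit) can catch that right after the `load_state_dict` call:

```python
import torch

def assert_fully_materialized(module: torch.nn.Module) -> None:
    """Raise if any parameter or buffer is still a meta tensor after loading.

    Hypothetical helper, not part of torchchat; useful when combining
    meta-device initialization with load_state_dict(assign=True, strict=False).
    """
    for name, tensor in list(module.named_parameters()) + list(module.named_buffers()):
        if tensor.is_meta:
            raise RuntimeError(f"{name} was never materialized (still on the meta device)")
```

For example, calling `assert_fully_materialized(model)` after loading would flag a rotary-embedding cache that the reinitialization loop above had missed.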

0 commit comments