This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

Contributor

@iseeyuan iseeyuan commented Sep 25, 2024

As titled. The Flamingo model was initialized on the meta device. Then:

  • the checkpoint is loaded, using 1x the weight size
  • the model is initialized again (probably to populate the buffers in rotary embedding), using another 1x the weight size.
    So for a model with 21 GB of weights, memory usage is 42 GB after model loading.

Context:
Buffers in rotary embedding are not included in the checkpoint.
Instead, they are computed during initialization. Since buffers on the meta device
do not hold any actual values, they need to be reinitialized on the actual
device.
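A quick illustration of why the reinitialization is needed: tensors created under the meta device carry only shape and dtype metadata, with no backing storage to read values from.

```python
import torch

# Factory calls under the meta device allocate no real memory.
with torch.device("meta"):
    t = torch.arange(4).float()

print(t.is_meta)   # the tensor is metadata-only
print(t.shape)     # shape is tracked, but t has no actual values;
                   # accessing its data (e.g. t.tolist()) would raise an error
```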

Fix:
Only reinitialize those buffers, without reinitializing the entire model.
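The pattern above can be sketched as follows. This is a minimal, hypothetical example (not the actual torchchat code): the `RotaryEmbedding` module and its `rope_init` method are stand-ins for the real model, and the checkpoint-loading step is elided. The point is that only the computed buffers are rebuilt on the real device, rather than running the whole `__init__` a second time.

```python
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    # Hypothetical simplified module: its buffer is computed at init time
    # and is NOT stored in the checkpoint.
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        self.dim, self.max_seq_len = dim, max_seq_len
        self.register_buffer("theta", torch.empty(dim // 2), persistent=False)
        self.rope_init()

    def rope_init(self):
        # Recompute buffer values; safe to call again after leaving meta.
        theta = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.theta = theta

# 1) Build the model on the meta device: no real memory is allocated.
with torch.device("meta"):
    rope = RotaryEmbedding(dim=64, max_seq_len=128)

# 2) Load the checkpoint weights here (1x weight memory). The rotary
#    buffers are not in the checkpoint, so they stay as meta tensors.

# 3) Instead of reinitializing the whole model (another 1x of memory),
#    materialize storage on the real device and recompute only the buffers.
rope = rope.to_empty(device="cpu")
rope.rope_init()
print(rope.theta.is_meta)  # False: buffer now holds real values on CPU
```

With `torch.nn.Module.to_empty` the parameters get uninitialized storage (to be overwritten by the checkpoint load), so the per-buffer `rope_init` call is the only extra compute needed.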

Peak memory:
max_seq_len = 128: before: 43.57 GB, after: 28.55 GB
max_seq_len = 8192: before: 45.47 GB, after: 30.5 GB

@pytorch-bot

pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1201

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit eb68319 with merge base e27e162 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 25, 2024
@Jack-Khuu Jack-Khuu merged commit f0a03a7 into main Sep 25, 2024
51 checks passed

5 participants