This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@kwen2501 (Contributor) commented on Sep 17, 2024

Previously, the KV cache was turned off in distributed cases; this PR enables it.

The cache size depends on the TP degree, because the attention heads are divided across ranks.

For this information to be available, we lowered KV cache instantiation from the root level to the attention layer (which actually owns the cache and knows how it is being tensor-parallelized).
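For illustration only, here is a minimal sketch of what sizing the cache per rank can look like once setup is lowered into the attention module. The names used here (Attention.setup_cache, tp_degree, and this KVCache constructor signature) are assumptions for the sketch and need not match the actual torchchat code.

import torch
import torch.nn as nn


class KVCache(nn.Module):
    # Pre-allocated key/value buffers sized for the heads local to this rank.
    def __init__(self, max_batch_size, max_seq_length, n_local_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        shape = (max_batch_size, n_local_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))


class Attention(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.tp_degree = 1  # updated when the layer is tensor-parallelized
        self.kv_cache = None

    def setup_cache(self, max_batch_size, max_seq_length):
        # The attention layer knows its own TP degree, so it can size the
        # cache for only the heads that live on this rank.
        n_local_heads = self.n_heads // self.tp_degree
        self.kv_cache = KVCache(
            max_batch_size, max_seq_length, n_local_heads, self.head_dim
        )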

pytorch-bot (bot) commented on Sep 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1154

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 1e78f59 with merge base 16b3d64:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Sep 17, 2024
self.max_batch_size = max_batch_size
for b in self.layers.values():
    b.attention.kv_cache = KVCache(
        max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
A reviewer (Contributor) commented on the snippet above:

This is fine, just calling it out as a note to self.

The local head_dim here is defined as self.config.dim // self.config.n_heads.

In Attention, head_dim is pulled from config.head_dim; recall that TransformerArgs (config) also initializes head_dim as dim // n_heads, so the two values match.
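To spell out the equivalence, here is a hypothetical illustration with simplified field names, not the actual TransformerArgs definition:

from dataclasses import dataclass


@dataclass
class TransformerArgs:
    dim: int = 4096
    n_heads: int = 32
    head_dim: int = 0  # filled in below if not given explicitly

    def __post_init__(self):
        if self.head_dim == 0:
            self.head_dim = self.dim // self.n_heads


config = TransformerArgs()
# head_dim computed at the call site (as in the root-level setup_caches) ...
local_head_dim = config.dim // config.n_heads
# ... matches the value Attention reads from config.head_dim.
assert local_head_dim == config.head_dim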

for b in self.layers.values():
    b.attention.kv_cache = KVCache(
        max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
# Lower the setup_cache call to the attention module because tensor
A reviewer (Contributor) commented:

head_dim as defined on line 440 is no longer used; let's remove it.

The PR author (Contributor) replied:
Good point, thanks!

@lessw2020 (Contributor) left a review comment:

Thanks for enabling, looks good!

@kwen2501 merged commit 0bc64c2 into main on Sep 17, 2024, with 49 of 51 checks passed.
