[Distributed] Enable KV cache #1154
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1154
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures as of commit 1e78f59 with merge base 16b3d64.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
self.max_batch_size = max_batch_size
for b in self.layers.values():
    b.attention.kv_cache = KVCache(
        max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
This is fine, just calling it out as a note to self:
The local head_dim here is defined as self.config.dim // self.config.n_heads.
In Attention, head_dim is pulled from config.head_dim; recall that TransformerArgs (config) also initializes head_dim as dim // n_heads, so the two definitions agree.
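To illustrate why the two definitions agree, here is a minimal sketch; the class below is a simplified stand-in for TransformerArgs, not the actual torchchat code:

```python
# Simplified stand-in for TransformerArgs; only the fields relevant here.
class TransformerArgs:
    def __init__(self, dim: int, n_heads: int):
        self.dim = dim
        self.n_heads = n_heads
        # config.head_dim is initialized the same way the local head_dim is computed.
        self.head_dim = dim // n_heads


config = TransformerArgs(dim=4096, n_heads=32)

# Local head_dim as computed in setup_caches:
head_dim = config.dim // config.n_heads

# Attention pulls head_dim from the config; both resolve to the same value.
assert head_dim == config.head_dim == 128
```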
for b in self.layers.values():
    b.attention.kv_cache = KVCache(
        max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
# Lower the setup_cache call to the attention module because tensor
head_dim as defined on line 440 is no longer used; let's remove it.
Good point, thanks!
Thanks for enabling, looks good!
Previously the KV cache was turned off in distributed cases; this PR enables it.
The cache size depends on the TP degree, because the attention heads are divided across ranks.
For this information to be available, we lowered KV cache instantiation from the root level to the attention layer, which actually owns the cache and knows how it is being tensor-parallelized.
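As a rough illustration of the idea, here is a minimal sketch with simplified Attention and KVCache classes; the names and signatures are assumptions based on the quoted diff, not the exact torchchat implementation. The key point is that the cache is shaped by the per-rank (local) head count, which only the attention layer knows once it has been tensor-parallelized.

```python
import torch
import torch.nn as nn


class KVCache(nn.Module):
    """Preallocated K/V buffers sized by the local (per-rank) head count."""

    def __init__(self, max_batch_size, max_seq_length, n_local_heads, head_dim,
                 dtype=torch.bfloat16):
        super().__init__()
        shape = (max_batch_size, n_local_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))


class Attention(nn.Module):
    def __init__(self, dim, n_heads, tp_degree=1):
        super().__init__()
        self.head_dim = dim // n_heads
        # After tensor parallelism, each rank holds only n_heads / tp_degree heads.
        self.n_local_heads = n_heads // tp_degree
        self.kv_cache = None

    def setup_cache(self, max_batch_size, max_seq_length):
        # The attention layer owns the cache and knows its local head count,
        # so the per-rank cache is 1/tp_degree of the single-device size.
        self.kv_cache = KVCache(
            max_batch_size, max_seq_length, self.n_local_heads, self.head_dim
        )


# Usage: the root model delegates per layer instead of constructing
# KVCache itself with a global head count.
attn = Attention(dim=4096, n_heads=32, tp_degree=4)
attn.setup_cache(max_batch_size=1, max_seq_length=2048)
print(attn.kv_cache.k_cache.shape)  # torch.Size([1, 8, 2048, 128])
```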