-
Notifications
You must be signed in to change notification settings - Fork 755
Yield per-document RoPE position ids from dataset #2560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7baabbd
1d27557
d2433d5
4e6888b
13c1d03
338bb8f
3e23b47
9cf6a39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |||||
| from typing import Literal | ||||||
|
|
||||||
| import torch | ||||||
| from torch.distributed.tensor import DTensor, Replicate, Shard | ||||||
|
|
||||||
| from torchtitan.protocols.module import Module | ||||||
|
|
||||||
|
|
@@ -289,6 +290,43 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |||||
| return torch.cat((-x2, x1), dim=-1) | ||||||
|
|
||||||
|
|
||||||
| def _maybe_wrap_positions( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
maybe
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
sg, please leave a TODO: in a full DTensor rewrite, we should make `positions` a DTensor in/right after dataloading, together with inputs and labels. cc @fegin |
||||||
| positions: torch.Tensor | None, | ||||||
| x: torch.Tensor, | ||||||
| ) -> torch.Tensor | None: | ||||||
| """Wrap positions as a DTensor deriving mesh and placements from x (xq/xk). | ||||||
|
|
||||||
| TODO: In a full DTensor rewrite, positions should be made a DTensor | ||||||
| in/right after dataloading, together with inputs and labels. | ||||||
|
|
||||||
| When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS), | ||||||
| x is a DTensor but positions is a plain tensor. The downstream | ||||||
| torch.gather requires both operands to be the same type. | ||||||
|
|
||||||
| Positions (bsz, seqlen) has fewer dimensions than x (bsz, seqlen, | ||||||
| n_heads, head_dim), so we only preserve Shard placements for shared | ||||||
| dimensions. Shard dims beyond positions' rank (e.g. Shard(2) for TP | ||||||
| on heads) become Replicate. | ||||||
| """ | ||||||
| if ( | ||||||
| positions is not None | ||||||
| and isinstance(x, DTensor) | ||||||
| and not isinstance(positions, DTensor) | ||||||
| ): | ||||||
| ndim = positions.ndim | ||||||
| placements = tuple( | ||||||
| p if not isinstance(p, Shard) or p.dim < ndim else Replicate() | ||||||
| for p in x.placements | ||||||
| ) | ||||||
| positions = DTensor.from_local( | ||||||
| positions, | ||||||
| x.device_mesh, | ||||||
| placements, | ||||||
| run_check=False, | ||||||
| ) | ||||||
| return positions | ||||||
|
|
||||||
|
|
||||||
| # TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex | ||||||
| def apply_rotary_emb_complex( | ||||||
| xq: torch.Tensor, | ||||||
|
|
@@ -304,6 +342,7 @@ def apply_rotary_emb_complex( | |||||
| freqs_cis: (max_seqlen, head_dim // 2) complex | ||||||
| positions: optional position indices | ||||||
| """ | ||||||
| positions = _maybe_wrap_positions(positions, xq) | ||||||
| xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | ||||||
| xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | ||||||
| freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions) | ||||||
|
|
@@ -324,6 +363,7 @@ def apply_rotary_emb_single_complex( | |||||
| freqs_cis: (max_seqlen, head_dim // 2) complex | ||||||
| positions: optional position indices | ||||||
| """ | ||||||
| positions = _maybe_wrap_positions(positions, x) | ||||||
| dtype = x.dtype | ||||||
| x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) | ||||||
| freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions) | ||||||
|
|
@@ -345,6 +385,7 @@ def apply_rotary_emb_cos_sin( | |||||
| rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated | ||||||
| positions: optional position indices | ||||||
| """ | ||||||
| positions = _maybe_wrap_positions(positions, xq) | ||||||
| head_dim = xq.shape[-1] | ||||||
| rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions) | ||||||
| cos = rope_cache[..., :head_dim].to(device=xq.device) | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.