-
Notifications
You must be signed in to change notification settings - Fork 755
Yield per-document RoPE position ids from dataset #2560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
7baabbd
1d27557
d2433d5
4e6888b
13c1d03
338bb8f
3e23b47
9cf6a39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |||||
| from typing import Literal | ||||||
|
|
||||||
| import torch | ||||||
| from torch.distributed.tensor import DTensor, Replicate | ||||||
|
|
||||||
| from torchtitan.protocols.module import Module | ||||||
|
|
||||||
|
|
@@ -289,6 +290,35 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |||||
| return torch.cat((-x2, x1), dim=-1) | ||||||
|
|
||||||
|
|
||||||
| def _maybe_wrap_positions( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. maybe
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sg, please leave a TODO: in full DTensor rewrite, we should make cc @fegin |
||||||
| positions: torch.Tensor | None, | ||||||
| freqs_cis: torch.Tensor, | ||||||
| ) -> torch.Tensor | None: | ||||||
| """Wrap positions as a DTensor if freqs_cis is a DTensor. | ||||||
|
|
||||||
| When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS), | ||||||
| freqs_cis is a DTensor (Replicate) but positions is a plain tensor. | ||||||
| The downstream torch.gather requires both operands to be the same type. | ||||||
| Since positions (int64 indices) has no gradient, grad_placements is | ||||||
| not needed. | ||||||
| """ | ||||||
| if ( | ||||||
| positions is not None | ||||||
| and isinstance(freqs_cis, DTensor) | ||||||
| and not isinstance(positions, DTensor) | ||||||
| ): | ||||||
| assert all( | ||||||
| isinstance(p, Replicate) for p in freqs_cis.placements | ||||||
| ), f"Expected Replicate placements on freqs_cis, got {freqs_cis.placements}" | ||||||
| positions = DTensor.from_local( | ||||||
| positions, | ||||||
| freqs_cis.device_mesh, | ||||||
| freqs_cis.placements, | ||||||
|
||||||
| run_check=False, | ||||||
| ) | ||||||
| return positions | ||||||
|
|
||||||
|
|
||||||
| # TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex | ||||||
| def apply_rotary_emb_complex( | ||||||
| xq: torch.Tensor, | ||||||
|
|
@@ -304,6 +334,7 @@ def apply_rotary_emb_complex( | |||||
| freqs_cis: (max_seqlen, head_dim // 2) complex | ||||||
| positions: optional position indices | ||||||
| """ | ||||||
| positions = _maybe_wrap_positions(positions, freqs_cis) | ||||||
|
||||||
| xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | ||||||
| xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | ||||||
| freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions) | ||||||
|
|
@@ -324,6 +355,7 @@ def apply_rotary_emb_single_complex( | |||||
| freqs_cis: (max_seqlen, head_dim // 2) complex | ||||||
| positions: optional position indices | ||||||
| """ | ||||||
| positions = _maybe_wrap_positions(positions, freqs_cis) | ||||||
| dtype = x.dtype | ||||||
| x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) | ||||||
| freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions) | ||||||
|
|
@@ -345,6 +377,7 @@ def apply_rotary_emb_cos_sin( | |||||
| rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated | ||||||
| positions: optional position indices | ||||||
| """ | ||||||
| positions = _maybe_wrap_positions(positions, rope_cache) | ||||||
| head_dim = xq.shape[-1] | ||||||
| rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions) | ||||||
| cos = rope_cache[..., :head_dim].to(device=xq.device) | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.