Empty file added requirements.txt
69 changes: 57 additions & 12 deletions scripts/train_eagle3.py
@@ -35,6 +35,7 @@
     destroy_distributed,
     get_dp_group,
     get_draft_dp_group,
+    get_draft_sp_group,
     get_tp_group,
     init_distributed,
 )
@@ -335,10 +336,6 @@ def sanity_check(args: Namespace) -> None:
         args.draft_accumulation_steps = (
             args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
         )
-    if args.attention_backend == "usp":
-        assert (
-            args.train_hidden_states_path is not None
-        ), "train_hidden_states_path should not be None for usp"


 def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
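The removed assert above is what previously forced the usp backend to run offline (with precomputed hidden states). The accumulation scaling kept just above it preserves the effective batch size under sequence parallelism; a minimal sketch of that arithmetic (the helper below is hypothetical, not part of the PR):

```python
# Hypothetical restatement of the scaling in sanity_check: with sequence
# parallelism, each rank processes 1/(ulysses * ring) of every sequence, so
# accumulation steps are multiplied to keep the effective batch size constant.
def scaled_accumulation_steps(base: int, sp_ulysses_size: int, sp_ring_size: int) -> int:
    return base * sp_ulysses_size * sp_ring_size

# e.g. a base accumulation of 4 under 2-way Ulysses x 2-way ring SP -> 16
assert scaled_accumulation_steps(4, 2, 2) == 16
```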
@@ -410,6 +407,9 @@ def build_dataloaders(
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
     train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    is_online = (
+        args.train_data_path is not None and args.train_hidden_states_path is None
+    )
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
@@ -431,7 +431,7 @@
             cache_key=cache_key,
         )

-    if args.train_hidden_states_path is not None:
+    if not is_online:
        train_eagle3_dataset = build_offline_eagle3_dataset(
             args.train_hidden_states_path,
             args.max_length,
@@ -443,7 +443,9 @@
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=(
-            get_draft_dp_group() if args.attention_backend == "usp" else get_dp_group()
+            get_draft_dp_group()
+            if args.attention_backend == "usp" and not is_online
+            else get_dp_group()
         ),
         is_vlm=args.is_vlm,
     )
@@ -473,7 +475,7 @@
         shuffle=False,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
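Both dataloader call sites now share the same rule: the draft-model DP group is used only for offline usp runs, while online usp (hidden states produced on the fly) and every non-usp configuration shard over the regular DP group. A hedged condensation of that predicate (the standalone helper is hypothetical; `get_draft_dp_group`/`get_dp_group` are the PR's own utilities):

```python
# Hypothetical helper mirroring the inline ternaries above.
def select_data_process_group(attention_backend: str, is_online: bool):
    # Offline usp shards data across the draft model's DP ranks; everything
    # else falls back to the global data-parallel group.
    if attention_backend == "usp" and not is_online:
        return get_draft_dp_group()
    return get_dp_group()
```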
@@ -630,13 +632,56 @@ def record_metrcs(
     tracker.log(logdict, step=global_step)


-def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
+def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
"""
Get the data shard from the tensor.
Process: TP split -> Pad to Max Len -> SP gather.
"""
-    tp_size = dist.get_world_size(get_tp_group())
-    tp_rank = dist.get_rank(get_tp_group())
-    return tensor.chunk(tp_size, dim=0)[tp_rank]
+    # 1. TP: slice the tensor along the batch dimension
+    tp_group = get_tp_group()
+    tp_size = dist.get_world_size(tp_group)
+    tp_rank = dist.get_rank(tp_group)
+
+    local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank]
+
+    # 2. SP: handle dynamic sequence lengths and gather
+    sp_group = get_draft_sp_group()
+
+    if sp_group is not None and dist.get_world_size(sp_group) > 1:
+        sp_world_size = dist.get_world_size(sp_group)
+        local_seq_len = local_tp_shard.size(sp_dim)
+
+        # Find the global max sequence length in the SP group
+        len_tensor = torch.tensor(
+            [local_seq_len], device=local_tp_shard.device, dtype=torch.long
+        )
+        dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
+        max_seq_len = len_tensor.item()
+
+        # Pad the local tensor if necessary.
+        # Shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1.
+        if local_seq_len < max_seq_len:
+            pad_size = max_seq_len - local_seq_len
+
+            # F.pad expects (before, after) pairs ordered from the last dim backward
+            pad_config = [0] * (local_tp_shard.ndim * 2)
+            pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1
+            pad_config[pad_idx] = pad_size
+
+            # Pad value: 0 is standard; ensure it matches your pad_token_id logic if needed
+            local_tp_shard_padded = nn.functional.pad(local_tp_shard, pad_config, value=0)
+        else:
+            local_tp_shard_padded = local_tp_shard
+
+        gathered_shards = [
+            torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)
+        ]
+        dist.all_gather(
+            gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group
+        )
+
+        return torch.cat(gathered_shards, dim=sp_dim)
+
+    return local_tp_shard


 def main():
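The pad-index arithmetic in `get_dp_data_shard_from_tp` is the subtle part: `torch.nn.functional.pad` takes (before, after) pairs ordered from the last dimension backward, so the "after" slot for `sp_dim` sits at `(ndim - 1 - sp_dim) * 2 + 1`. A self-contained check of that indexing (a sketch, independent of the PR's code):

```python
import torch
import torch.nn.functional as F

def pad_after_dim(t: torch.Tensor, dim: int, pad_size: int) -> torch.Tensor:
    # F.pad's spec is [before_last, after_last, before_second_last, ...]
    pad_config = [0] * (t.ndim * 2)
    pad_config[(t.ndim - 1 - dim) * 2 + 1] = pad_size  # the "after" slot for `dim`
    return F.pad(t, pad_config, value=0)

x = torch.ones(2, 5, 8)                  # [batch, seq, hidden]
y = pad_after_dim(x, dim=1, pad_size=3)  # right-pad the sequence dimension
assert y.shape == (2, 8, 8)
assert y[:, 5:, :].eq(0).all()           # padded tail is zeros
```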
12 changes: 12 additions & 0 deletions specforge/layers/ring/__init__.py
@@ -0,0 +1,12 @@
+# adapted from https://github.com/feifeibear/long-context-attention/tree/main/yunchang
+from .ring_flash_attn import (
+    ring_flash_attn_func,
+    ring_flash_attn_kvpacked_func,
+    ring_flash_attn_qkvpacked_func,
+)
+
+__all__ = [
+    "ring_flash_attn_func",
+    "ring_flash_attn_kvpacked_func",
+    "ring_flash_attn_qkvpacked_func",
+]