Commit 272799c

online mode
1 parent 9ef7f64 commit 272799c

File tree

2 files changed: +63 -13 lines changed

scripts/train_eagle3.py

Lines changed: 58 additions & 13 deletions
@@ -36,7 +36,7 @@
     get_dp_group,
     get_draft_dp_group,
     get_tp_group,
-    init_distributed,
+    init_distributed, get_draft_sp_group,
 )
 from specforge.modeling.target import (
     Eagle3TargetModel,
@@ -335,10 +335,6 @@ def sanity_check(args: Namespace) -> None:
     args.draft_accumulation_steps = (
         args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
     )
-    if args.attention_backend in ("usp", "usp_fa"):
-        assert (
-            args.train_hidden_states_path is not None
-        ), "train_hidden_states_path should not be None for usp"
 
 
 def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
@@ -410,6 +406,9 @@ def build_dataloaders(
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
     train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    is_online = (
+        args.train_data_path is not None and args.train_hidden_states_path is None
+    )
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
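Reading the new flag: the diff infers "online" mode purely from the argument combination rather than a dedicated CLI switch. A minimal sketch of the implied semantics follows; the helper name and the example paths are illustrative assumptions, not part of the commit.

from argparse import Namespace

def infer_is_online(args: Namespace) -> bool:
    # Online: only raw training conversations are given; hidden states come from the
    # target model during training. Offline: precomputed hidden states are read from
    # train_hidden_states_path via build_offline_eagle3_dataset, as in the hunks below.
    return args.train_data_path is not None and args.train_hidden_states_path is None

# Usage sketch with hypothetical paths:
assert infer_is_online(Namespace(train_data_path="chat.jsonl", train_hidden_states_path=None))
assert not infer_is_online(Namespace(train_data_path="chat.jsonl", train_hidden_states_path="hs_cache/"))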
@@ -431,7 +430,7 @@ def build_dataloaders(
             cache_key=cache_key,
         )
 
-    if args.train_hidden_states_path is not None:
+    if not is_online:
         train_eagle3_dataset = build_offline_eagle3_dataset(
             args.train_hidden_states_path,
             args.max_length,
@@ -444,7 +443,7 @@ def build_dataloaders(
         shuffle=True,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
@@ -475,7 +474,7 @@ def build_dataloaders(
         shuffle=False,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
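The same condition now guards the process-group choice for both the train and eval dataloaders. A possible refactor sketch, assuming the get_draft_dp_group / get_dp_group helpers the script already imports; the function below is illustrative only and is not part of the commit.

def select_data_shard_group(attention_backend: str, is_online: bool):
    # Offline USP training shards data across each draft DP group; online mode
    # (and every non-USP backend) shards across the full DP group instead.
    if attention_backend == "usp" and not is_online:
        return get_draft_dp_group()
    return get_dp_group()

# Both dataloader calls could then pass:
#     process_group=select_data_shard_group(args.attention_backend, is_online)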
@@ -632,13 +631,59 @@ def record_metrcs(
     tracker.log(logdict, step=global_step)
 
 
-def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+
+def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
-    Get the data shard from the tensor.
+    Process: TP split -> Pad to Max Len -> SP gather.
     """
-    tp_size = dist.get_world_size(get_tp_group())
-    tp_rank = dist.get_rank(get_tp_group())
-    return tensor.chunk(tp_size, dim=0)[tp_rank]
+    # 1. TP: Slice the tensor along the batch dimension
+    tp_group = get_tp_group()
+    tp_size = dist.get_world_size(tp_group)
+    tp_rank = dist.get_rank(tp_group)
+
+    local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank]
+
+    # 2. SP: Handle dynamic sequence lengths and gather
+    sp_group = get_draft_sp_group()
+
+    if sp_group is not None and dist.get_world_size(sp_group) > 1:
+        sp_world_size = dist.get_world_size(sp_group)
+
+        # --- Fix for variable sequence lengths ---
+        local_seq_len = local_tp_shard.size(sp_dim)
+
+        # Find the global max sequence length in the SP group
+        len_tensor = torch.tensor([local_seq_len], device=local_tp_shard.device, dtype=torch.long)
+        dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
+        max_seq_len = len_tensor.item()
+
+        # Pad the local tensor if necessary
+        # Assuming shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1
+        if local_seq_len < max_seq_len:
+            pad_size = max_seq_len - local_seq_len
+
+            # Construct the pad tuple for F.pad (applied from the last dim backwards);
+            # initialize with all zeros (no padding for other dims)
+            pad_config = [0] * (local_tp_shard.ndim * 2)
+
+            pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1
+            pad_config[pad_idx] = pad_size
+
+            # Pad value: 0 is standard; ensure it matches your pad_token_id logic if needed
+            local_tp_shard_padded = F.pad(local_tp_shard, pad_config, value=0)
+        else:
+            local_tp_shard_padded = local_tp_shard
+
+        gathered_shards = [torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)]
+        dist.all_gather(gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group)
+
+        return torch.cat(gathered_shards, dim=sp_dim)
+
+    return local_tp_shard
 
 
 def main():
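The pad-index arithmetic in the new get_dp_data_shard_from_tp is easy to get wrong, so here is a small self-contained check of just that part. It is illustrative only; the [Batch, Seq, Hidden] shape and pad_size values are made up.

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)                 # [Batch, Seq, Hidden]
sp_dim, pad_size = 1, 3
# F.pad reads its pad list from the last dimension backwards:
# (left_dim2, right_dim2, left_dim1, right_dim1, left_dim0, right_dim0)
pad_config = [0] * (x.ndim * 2)          # [0, 0, 0, 0, 0, 0]
pad_idx = (x.ndim - 1 - sp_dim) * 2 + 1  # (3 - 1 - 1) * 2 + 1 = 3
pad_config[pad_idx] = pad_size           # [0, 0, 0, 3, 0, 0] -> right-pad dim 1
assert F.pad(x, pad_config, value=0).shape == (2, 8, 8)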

specforge/layers/ring/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -5,3 +5,8 @@
     ring_flash_attn_qkvpacked_func,
 )
 
+__all__ = [
+    "ring_flash_attn_func",
+    "ring_flash_attn_kvpacked_func",
+    "ring_flash_attn_qkvpacked_func",
+]
