
Commit 14ffd72 (parent: ffc4ab7)

feat(dflash): add checkpoint resume support and clean up comments

2 files changed: 66 additions (+), 30 deletions (−)

examples/run_qwen3_8b_dflash_online.sh (4 additions, 4 deletions)

@@ -4,7 +4,7 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 ROOT_DIR=$(dirname $SCRIPT_DIR)
 export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
 export SPECFORGE_DATA_NUM_PROC=32
-NUM_GPUS=${1:-1}
+NUM_GPUS=${1:-8}
 
 ATTENTION_BACKEND=${2:-flex_attention}
 
@@ -14,8 +14,8 @@ torchrun \
     $ROOT_DIR/scripts/train_dflash.py \
     --target-model-path Qwen/Qwen3-8B \
     --draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json \
-    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
-    --output-dir $ROOT_DIR/outputs/qwen3-8b-dflash-sharegpt \
+    --train-data-path $ROOT_DIR/cache/dataset/perfectblend_qwen3-8b_regen.jsonl \
+    --output-dir $ROOT_DIR/outputs/qwen3-8b-perfectblend \
     --num-epochs 6 \
     --batch-size 4 \
     --learning-rate 6e-4 \
@@ -31,4 +31,4 @@ torchrun \
     --save-interval 1000 \
     --report-to wandb \
     --wandb-project specforge-qwen3-8b-dflash \
-    --wandb-name qwen3-8b-dflash-sharegpt
+    --wandb-name qwen3-8b-dflash-perfectblend

scripts/train_dflash.py (62 additions, 26 deletions)

@@ -33,7 +33,7 @@
 from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
 from specforge.optimizer import BF16Optimizer
 from specforge.tracker import create_tracker
-from specforge.utils import print_on_rank0, print_with_rank
+from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank
 
 
 def parse_args():
@@ -108,6 +108,12 @@ def parse_args():
     training_group.add_argument("--accumulation-steps", type=int, default=1)
     training_group.add_argument("--seed", type=int, default=42)
     training_group.add_argument("--resume", action="store_true")
+    training_group.add_argument(
+        "--ckpt-dir",
+        type=str,
+        default=None,
+        help="Directory of the checkpoint to resume training from",
+    )
 
     output_group = parser.add_argument_group("output")
     output_group.add_argument("--output-dir", type=str, required=True)
@@ -162,25 +168,21 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
         draft_config = AutoConfig.from_pretrained(args.draft_config_path)
         print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
     else:
-        # Load config from HF (needed for structure info even if backend is sglang)
         target_config = AutoConfig.from_pretrained(args.target_model_path)
         draft_config = AutoConfig.from_pretrained(args.target_model_path)
         draft_config.num_hidden_layers = args.num_draft_layers
         draft_config.block_size = args.block_size
         draft_config.num_target_layers = target_config.num_hidden_layers
         print_on_rank0("Auto-generated draft config from target model")
 
-    # Ensure dflash_config exists in config (for target_layer_ids / mask_token_id)
     if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None:
         draft_config.dflash_config = {}
 
-    # Set attention implementation based on backend
     draft_config._attn_implementation = args.attention_backend
     print_on_rank0(f"Using attention backend: {args.attention_backend}")
 
     draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
 
-    # Set capture layers for target model based on draft model config
     target_model.set_capture_layers(draft_model.target_layer_ids)
 
     print_on_rank0(
@@ -199,7 +201,6 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
     """Build train and eval dataloaders."""
     import hashlib
 
-    # convert to dataloader
     cache_params_string = (
         f"{args.train_data_path}-"
         f"{args.max_length}-"
@@ -220,7 +221,6 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
         num_proc=args.build_dataset_num_proc,
     )
 
-    # Filter out samples with too few loss tokens (DFlash requires >= 2 * block_size)
     min_loss_tokens = 2 * args.block_size
     original_size = len(train_eagle3_dataset)
     train_eagle3_dataset = train_eagle3_dataset.filter(
@@ -287,7 +287,6 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
 
     draft_model.save_pretrained(save_dir, state_dict=draft_state_dict)
 
-    # Copy dflash.py for inference compatibility (matches auto_map in config)
     modeling_src = os.path.join(
         os.path.dirname(__file__),
         "..",
@@ -331,16 +330,13 @@ def record_metrics(
 
 
 def main():
-    # Configure logging to ensure we see INFO logs
+
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    # Force the root logger to INFO as well, just in case
     logging.getLogger().setLevel(logging.INFO)
-
-    # Filter annoying FSDP warnings
     warnings.filterwarnings(
         "ignore",
         "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed",
@@ -354,9 +350,45 @@ def main():
 
     target_model, draft_model = build_models(args)
 
+    draft_model_last_checkpoint = None
+    if args.ckpt_dir is not None:
+        if os.path.isdir(args.ckpt_dir):
+            draft_model_last_checkpoint = args.ckpt_dir
+            print_on_rank0(f"Using checkpoint: {draft_model_last_checkpoint}")
+        else:
+            raise ValueError(
+                f"Provided ckpt dir {args.ckpt_dir} is not a valid directory."
+            )
+
+    if args.resume and os.path.isdir(args.output_dir):
+        draft_model_last_checkpoint = get_last_checkpoint(
+            args.output_dir, prefix=r"epoch_\d+_step"
+        )
+        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
+
+    resume_state = None
+    if draft_model_last_checkpoint:
+        loaded_model = DFlashDraftModel.from_pretrained(
+            draft_model_last_checkpoint, torch_dtype=torch.bfloat16
+        )
+        draft_model.load_state_dict(loaded_model.state_dict())
+        del loaded_model
+        print_on_rank0("Loaded draft model weights from checkpoint")
+
+        training_state_path = os.path.join(
+            draft_model_last_checkpoint, "training_state.pt"
+        )
+        if os.path.exists(training_state_path):
+            resume_state = torch.load(
+                training_state_path, map_location="cpu", weights_only=False
+            )
+            print_on_rank0(
+                f"Will resume from epoch {resume_state['epoch']}, "
+                f"step {resume_state['global_step']}"
+            )
+
     tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
 
-    # Get mask_token_id
     if args.mask_token_id is not None:
         mask_token_id = args.mask_token_id
     elif tokenizer.mask_token_id is not None:
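
Note: `get_last_checkpoint` is imported from `specforge.utils`, but its implementation is not part of this diff. A minimal sketch of the behavior the call above appears to rely on (scanning `--output-dir` for checkpoint subdirectories whose names match the `epoch_\d+_step` prefix and returning the latest one) might look like this; the body, the sorting key, and the return type are assumptions, not the actual SpecForge code.

```python
import os
import re
from typing import Optional


def get_last_checkpoint(output_dir: str, prefix: str = r"epoch_\d+_step") -> Optional[str]:
    """Hypothetical sketch: return the newest checkpoint directory in output_dir.

    Assumes checkpoints are saved as subdirectories named e.g. ``epoch_3_step_2000``
    so that ``prefix`` matches the start of the directory name.
    """
    pattern = re.compile(prefix)
    candidates = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if pattern.match(name) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not candidates:
        return None
    # Assumption: the most recently modified checkpoint directory is the latest one.
    return max(candidates, key=os.path.getmtime)
```
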
@@ -366,9 +398,6 @@
         mask_token_id = tokenizer.mask_token_id
     print_on_rank0(f"Using mask_token_id: {mask_token_id}")
 
-    # Write mask_token_id and target_layer_ids into draft config so that
-    # save_pretrained produces a config.json compatible with the official
-    # dflash inference code (which reads from config.dflash_config).
     draft_model.mask_token_id = mask_token_id
     draft_model.config.dflash_config["mask_token_id"] = mask_token_id
     draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids
@@ -380,10 +409,7 @@
     total_steps = args.num_epochs * steps_per_epoch
     print_on_rank0(f"Total training steps: {total_steps}")
 
-    # Note: We need embedding layer for DFlash wrapper.
-    # For SGLang backend, we can't easily get the embedding layer object.
-    # We use TargetEmbeddingsAndHead to efficiently load only needed weights.
-    print_on_rank0("Loading target embeddings and head efficiently...")
+    print_on_rank0("Loading target embeddings and head...")
     target_components = TargetEmbeddingsAndHead.from_pretrained(
         args.target_model_path,
         embed_key="model.embed_tokens.weight", # Adjust if Qwen/Llama differs
@@ -423,14 +449,25 @@
         total_steps=total_steps,
     )
 
+    start_epoch = 0
+    global_step = 0
+    if resume_state is not None:
+        optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
+        start_epoch = resume_state["epoch"]
+        global_step = resume_state["global_step"]
+        del resume_state
+        print_on_rank0(f"Restored scheduler, lr={optimizer.get_learning_rate():.6f}")
+
+    skip_steps = global_step - start_epoch * len(train_dataloader)
+
     print_on_rank0(f"Initializing tracker (report_to={args.report_to})...")
     tracker = create_tracker(args, args.output_dir)
     print_on_rank0("Tracker initialized successfully.")
 
-    global_step = 0
     last_time = time.time()
+    print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}")
 
-    for epoch in range(args.num_epochs):
+    for epoch in range(start_epoch, args.num_epochs):
         train_dataloader.sampler.set_epoch(epoch)
         draft_model.train()
 
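
Note: the resume branch above reads `epoch`, `global_step`, and `scheduler_state_dict` out of `training_state.pt`, but the saving side is not shown in this commit. The following is only a sketch of what `save_checkpoint` would need to write for the round trip to work; the helper name and the exact directory layout are assumptions.

```python
import os
import torch


def save_training_state(save_dir: str, epoch: int, global_step: int, optimizer) -> None:
    """Hypothetical sketch: persist the fields the resume logic loads back.

    ``optimizer`` is assumed to be the BF16Optimizer used in this script,
    exposing a ``scheduler`` object with a ``state_dict()`` method.
    """
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "scheduler_state_dict": optimizer.scheduler.state_dict(),
    }
    torch.save(state, os.path.join(save_dir, "training_state.pt"))
```

For `get_last_checkpoint` to pick such a directory up on the next run, `save_dir` would also have to follow the `epoch_\d+_step` naming implied by the prefix regex (for example `epoch_2_step_3000` under `--output-dir`); that convention is assumed rather than spelled out in this diff.
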
@@ -441,21 +478,20 @@
         else:
             progress_bar = train_dataloader
 
-        for data in progress_bar:
+        for step_in_epoch, data in enumerate(progress_bar):
+            if epoch == start_epoch and step_in_epoch < skip_steps:
+                continue
             global_step += 1
 
             input_ids = data["input_ids"].cuda()
             attention_mask = data["attention_mask"].cuda()
             loss_mask = data["loss_mask"].cuda()
 
-            # Generate context from Target Model (SGLang or HF)
-            # This calls the backend to get hidden states
             target_output = target_model.generate_dflash_data(
                 input_ids, attention_mask, loss_mask
             )
             hidden_states = target_output.hidden_states.cuda() # Ensure on GPU
 
-            # Forward pass (Parallel Training)
             loss, accuracy = dflash_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
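
Note on the `skip_steps` bookkeeping above: a mid-epoch checkpoint resumes by fast-forwarding the dataloader of the first resumed epoch instead of replaying it. A small worked example with illustrative numbers (not taken from the commit):

```python
# Illustrative numbers only: suppose one epoch is 1000 batches and the checkpoint
# recorded epoch=2, global_step=2500.
steps_per_epoch = 1000   # len(train_dataloader)
start_epoch = 2          # resume_state["epoch"]
global_step = 2500       # resume_state["global_step"]

# Batches already consumed inside epoch 2; the loop skips these and training
# continues at global_step 2501.
skip_steps = global_step - start_epoch * steps_per_epoch
assert skip_steps == 500
```

Because `train_dataloader.sampler.set_epoch(epoch)` re-seeds the shuffle per epoch, the skipped batches should be the same ones seen before the checkpoint, provided the seed, dataset, and world size are unchanged.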
