Skip to content

Commit ffc4ab7

Browse files
authored
Feat/dflash training improvements (#463)
* feat(dflash): add random anchor sampling, loss decay, and sync with upstream
  - Add random anchor sampling for block construction (paper Sec 4.2)
  - Add exponential loss decay weighting (paper Sec 4.2, Eq. 4, Appendix A.3.1)
  - Sync with upstream: dflash_config in config.json, mask_token_id from config, decoupled target_layer_ids
  - Align training hyperparams with paper (lr=6e-4, warmup=0.04, epochs=6, max_length=3072)
  - Fix auto_map and saved model file name for HuggingFace compatibility

* fix(dflash): per-sample anchor sampling with padding block isolation
  - Sample anchors independently per batch sample (max strategy)
  - Mark padding blocks with block_id=-1 for attention isolation
  - Padding blocks excluded from both attention and loss computation
  - Use absolute positions from gather_idx for position encoding
  - Per-sample block_ids throughout: attention mask, loss mask, noise input

* fix(dflash): align acceptance-rate metric with inference, and trust_remote_code kwarg
  - Use loss_mask to exclude prompt tokens from the acc calculation; measure only on completion/assistant blocks to match inference behavior
  - Replace token-level accuracy with cumprod-based acceptance length
  - Clean up debug prints and redundant comments
  - Add trust_remote_code kwarg
1 parent 6c27152 commit ffc4ab7

File tree

8 files changed

+351
-103
lines changed

8 files changed

+351
-103
lines changed

configs/longcat-flash-dflash.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
"attention_bias": false,
66
"attention_dropout": 0.0,
77
"auto_map": {
8-
"AutoModel": "modeling_dflash.DFlashDraftModel"
8+
"AutoModel": "dflash.DFlashDraftModel"
99
},
1010
"block_size": 16,
1111
"bos_token_id": 1,
12+
"dflash_config": {
13+
"mask_token_id": 2,
14+
"target_layer_ids": [1, 7, 13, 19, 25]
15+
},
1216
"dtype": "bfloat16",
1317
"eos_token_id": 2,
1418
"head_dim": 128,

configs/qwen3-8b-dflash.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
"attention_bias": false,
66
"attention_dropout": 0.0,
77
"auto_map": {
8-
"AutoModel": "modeling_dflash.DFlashDraftModel"
8+
"AutoModel": "dflash.DFlashDraftModel"
99
},
1010
"block_size": 16,
1111
"bos_token_id": 151643,
12+
"dflash_config": {
13+
"mask_token_id": 151669,
14+
"target_layer_ids": [1, 9, 17, 25, 33]
15+
},
1216
"dtype": "bfloat16",
1317
"eos_token_id": 151645,
1418
"head_dim": 128,

examples/run_longcat_flash_dflash_online.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,16 @@ torchrun \
2626
--train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
2727
--build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
2828
--output-dir $ROOT_DIR/outputs/longcat-flash-dflash-sharegpt \
29-
--num-epochs 20 \
29+
--num-epochs 6 \
3030
--batch-size 2 \
31-
--learning-rate 1e-4 \
32-
--max-length 2048 \
31+
--learning-rate 6e-4 \
32+
--warmup-ratio 0.04 \
33+
--max-grad-norm 1.0 \
34+
--max-length 3072 \
3335
--chat-template longcat \
36+
--random-anchor \
37+
--num-anchors 512 \
38+
--loss-decay-gamma 7.0 \
3439
--log-interval 50 \
3540
--save-interval 1000 \
3641
--report-to wandb \

examples/run_qwen3_8b_dflash_online.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@ torchrun \
1616
--draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json \
1717
--train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
1818
--output-dir $ROOT_DIR/outputs/qwen3-8b-dflash-sharegpt \
19-
--num-epochs 20 \
19+
--num-epochs 6 \
2020
--batch-size 4 \
21-
--learning-rate 1e-4 \
22-
--max-length 2048 \
21+
--learning-rate 6e-4 \
22+
--warmup-ratio 0.04 \
23+
--max-grad-norm 1.0 \
24+
--max-length 3072 \
2325
--chat-template qwen \
2426
--attention-backend $ATTENTION_BACKEND \
27+
--random-anchor \
28+
--num-anchors 512 \
29+
--loss-decay-gamma 7.0 \
2530
--log-interval 50 \
2631
--save-interval 1000 \
2732
--report-to wandb \

scripts/train_dflash.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,24 @@ def parse_args():
6767
model_group.add_argument(
6868
"--trust-remote-code", action="store_true", help="Trust remote code"
6969
)
70+
model_group.add_argument(
71+
"--random-anchor",
72+
action="store_true",
73+
help="Enable random anchor sampling for block construction (paper Sec 4.2).",
74+
)
75+
model_group.add_argument(
76+
"--num-anchors",
77+
type=int,
78+
default=512,
79+
help="Number of anchor positions per sequence when --random-anchor is set.",
80+
)
81+
model_group.add_argument(
82+
"--loss-decay-gamma",
83+
type=float,
84+
default=None,
85+
help="Gamma for exponential loss decay weighting (paper Eq.4). "
86+
"Suggested: 7 for block_size=16, 5 for 10, 4 for 8. None disables.",
87+
)
7088

7189
dataset_group = parser.add_argument_group("dataset")
7290
dataset_group.add_argument("--train-data-path", type=str, required=True)
@@ -81,11 +99,11 @@ def parse_args():
8199
)
82100

83101
training_group = parser.add_argument_group("training")
84-
training_group.add_argument("--num-epochs", type=int, default=3)
102+
training_group.add_argument("--num-epochs", type=int, default=6)
85103
training_group.add_argument("--batch-size", type=int, default=1)
86-
training_group.add_argument("--learning-rate", type=float, default=1e-4)
87-
training_group.add_argument("--max-length", type=int, default=2048)
88-
training_group.add_argument("--warmup-ratio", type=float, default=0.01)
104+
training_group.add_argument("--learning-rate", type=float, default=6e-4)
105+
training_group.add_argument("--max-length", type=int, default=3072)
106+
training_group.add_argument("--warmup-ratio", type=float, default=0.04)
89107
training_group.add_argument("--max-grad-norm", type=float, default=1.0)
90108
training_group.add_argument("--accumulation-steps", type=int, default=1)
91109
training_group.add_argument("--seed", type=int, default=42)
@@ -152,6 +170,10 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
152170
draft_config.num_target_layers = target_config.num_hidden_layers
153171
print_on_rank0("Auto-generated draft config from target model")
154172

173+
# Ensure dflash_config exists in config (for target_layer_ids / mask_token_id)
174+
if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None:
175+
draft_config.dflash_config = {}
176+
155177
# Set attention implementation based on backend
156178
draft_config._attn_implementation = args.attention_backend
157179
print_on_rank0(f"Using attention backend: {args.attention_backend}")
@@ -265,7 +287,7 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
265287

266288
draft_model.save_pretrained(save_dir, state_dict=draft_state_dict)
267289

268-
# Copy modeling_dflash.py for inference compatibility
290+
# Copy dflash.py for inference compatibility (matches auto_map in config)
269291
modeling_src = os.path.join(
270292
os.path.dirname(__file__),
271293
"..",
@@ -274,7 +296,7 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
274296
"draft",
275297
"dflash.py",
276298
)
277-
modeling_dst = os.path.join(save_dir, "modeling_dflash.py")
299+
modeling_dst = os.path.join(save_dir, "dflash.py")
278300
if os.path.exists(modeling_src):
279301
shutil.copy(modeling_src, modeling_dst)
280302

@@ -344,6 +366,14 @@ def main():
344366
mask_token_id = tokenizer.mask_token_id
345367
print_on_rank0(f"Using mask_token_id: {mask_token_id}")
346368

369+
# Write mask_token_id and target_layer_ids into draft config so that
370+
# save_pretrained produces a config.json compatible with the official
371+
# dflash inference code (which reads from config.dflash_config).
372+
draft_model.mask_token_id = mask_token_id
373+
draft_model.config.dflash_config["mask_token_id"] = mask_token_id
374+
draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids
375+
print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}")
376+
347377
train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)
348378

349379
steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
@@ -369,6 +399,9 @@ def main():
369399
block_size=draft_model.block_size,
370400
mask_token_id=mask_token_id,
371401
attention_backend=args.attention_backend,
402+
random_anchor=args.random_anchor,
403+
num_anchors=args.num_anchors,
404+
loss_decay_gamma=args.loss_decay_gamma,
372405
)
373406

374407
dflash_model = FSDP(

0 commit comments

Comments
 (0)