Skip to content

Commit 29a6ad5

Browse files
yizhang2077 and sleepcoo
authored and committed
support qwen3 coder draft model, add opc dataset process (sgl-project#73)
* support qwen3 coder draft model, add opc dataset process * rename config --------- Co-authored-by: lukec <118525388+sleepcoo@users.noreply.github.com>
1 parent 391cd8e commit 29a6ad5

File tree

6 files changed

+76
-1
lines changed

6 files changed

+76
-1
lines changed

benchmarks/run_gsm8k.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def few_shot_gsm8k(s, question):
8484
states = few_shot_gsm8k.run_batch(
8585
arguments,
8686
temperature=0,
87+
max_new_tokens=2048,
8788
num_threads=args.parallel,
8889
progress_bar=True,
8990
)

benchmarks/run_humaneval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def get_humaneval_answer(s, question):
6868
states = get_humaneval_answer.run_batch(
6969
questions,
7070
temperature=0,
71+
max_new_tokens=2048,
7172
num_threads=args.parallel,
7273
progress_bar=True,
7374
)
@@ -77,6 +78,7 @@ def get_humaneval_answer(s, question):
7778
num_output_tokens = sum(
7879
s.get_meta_info("answer")["completion_tokens"] for s in states
7980
)
81+
8082
output_throughput = num_output_tokens / latency
8183

8284
has_verify = "spec_verify_ct" in states[0].get_meta_info("answer")

benchmarks/run_math500.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def get_humaneval_answer(s, question):
3939
states = get_humaneval_answer.run_batch(
4040
questions,
4141
temperature=0,
42+
max_new_tokens=2048,
4243
num_threads=args.parallel,
4344
progress_bar=True,
4445
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
{
2+
"architectures": [
3+
"LlamaForCausalLMEagle3"
4+
],
5+
"attention_bias": false,
6+
"attention_dropout": 0.0,
7+
"bos_token_id": 151643,
8+
"eos_token_id": 151645,
9+
"head_dim": 128,
10+
"hidden_act": "silu",
11+
"hidden_size": 6144,
12+
"initializer_range": 0.02,
13+
"intermediate_size": 16384,
14+
"max_position_embeddings": 262144,
15+
"max_window_layers": 62,
16+
"model_type": "llama",
17+
"num_attention_heads": 96,
18+
"num_hidden_layers": 1,
19+
"num_key_value_heads":8,
20+
"rms_norm_eps": 1e-06,
21+
"rope_scaling": null,
22+
"rope_theta": 1000000,
23+
"sliding_window": null,
24+
"tie_word_embeddings": false,
25+
"torch_dtype": "bfloat16",
26+
"transformers_version": "4.51.0",
27+
"use_cache": true,
28+
"use_sliding_window": false,
29+
"vocab_size": 151936,
30+
"draft_vocab_size": 32000
31+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
2+
ROOT_DIR=$(dirname $SCRIPT_DIR)
3+
4+
# train eagle3 for qwen3-coder
5+
NUM_GPUS=${1:-8}
6+
7+
torchrun \
8+
--standalone \
9+
--nproc_per_node $NUM_GPUS \
10+
$ROOT_DIR/scripts/train_eagle3_offline.py \
11+
--target-model-path Qwen/Qwen3-Coder-480B-A35B-Instruct \
12+
--draft-model-config $ROOT_DIR/configs/qwen3-coder-480B-A35B-instruct-eagle3.json \
13+
--train-data-path $ROOT_DIR/cache/dataset/opc.jsonl \
14+
--train-hidden-states-path $ROOT_DIR/cache/hidden_states \
15+
--output-dir $ROOT_DIR/outputs/Qwen3-Coder-480B-A35B-Instruct \
16+
--num-epochs 10 \
17+
--batch-size 1 \
18+
--learning-rate 1e-4 \
19+
--max-length 2048 \
20+
--chat-template qwen \
21+
--resume

scripts/prepare_data.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def parse_args():
3434
parser.add_argument(
3535
"--dataset",
3636
type=str,
37-
choices=["ultrachat", "sharegpt"],
37+
choices=["ultrachat", "sharegpt", "opc"],
3838
help="The demo dataset to quickly run the training for speculative decoding",
3939
)
4040
parser.add_argument(
@@ -108,6 +108,20 @@ def load_dataset_from_path(data_path: Path):
108108
return ds
109109

110110

111+
import hashlib
112+
113+
114+
def process_opc_sft_stage1(row) -> Dict:
115+
row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest()
116+
return {
117+
"id": row_id,
118+
"conversations": [
119+
{"role": "user", "content": row["instruction"]},
120+
{"role": "assistant", "content": row["output"]},
121+
],
122+
}
123+
124+
111125
def main():
112126
args = parse_args()
113127
# load dataset
@@ -121,6 +135,11 @@ def main():
121135
print("Loading dataset from custom data path: ", args.data_path)
122136
ds = load_dataset_from_path(Path(args.data_path))
123137
proc_fn = process_sharegpt_row
138+
elif args.dataset == "opc":
139+
ds = load_dataset(
140+
"OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
141+
)["train"]
142+
proc_fn = process_opc_sft_stage1
124143
else:
125144
raise ValueError(
126145
f"This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script."

0 commit comments

Comments (0)