30 changes: 29 additions & 1 deletion README.md
@@ -120,7 +120,11 @@ You need to specify the following arguments:

### 🤩 Prepare your own dataset

Besides the provided ShareGPT/Ultrachat datasets, you can also prepare your own dataset. You should prepare the dataset in jsonl format and the schema should look like this:
Besides the provided ShareGPT/Ultrachat datasets, you can also prepare your own dataset. We support two formats:

#### Option 1: Conversation Format

You should prepare the dataset in jsonl format, where each line follows this schema:

```json
{
@@ -134,6 +138,30 @@ Besides the provided ShareGPT/Ultrachat datasets, you can also prepare your own
}
```
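
For illustration, a record in this format can be appended to a jsonl file as in the sketch below; the field names (`id`, `conversations`, `role`, `content`) are assumptions based on common ShareGPT-style schemas, so follow the exact schema documented in the README.

```python
# A minimal sketch of appending one conversation-format record to a jsonl file.
# Field names ("id", "conversations", "role", "content") are assumptions based
# on common ShareGPT-style schemas; follow the exact schema documented above.
import json

record = {
    "id": "sample-0001",
    "conversations": [
        {"role": "user", "content": "Hello, what can you do?"},
        {"role": "assistant", "content": "I can answer questions and help with everyday tasks."},
    ],
}

with open("my_dataset.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```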

#### Option 2: Pre-formatted Text Format

If you already have conversations formatted with a specific chat template, you can use the pre-formatted text directly:

```json
{
"id": "xxxx",
"text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there!<|im_end|>\n"
}
```

This format is useful when you already have prompts rendered with the chat template used to train the target model, together with raw generations produced by the target model.
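
One way to produce such records is to render your conversations with the same chat template the target model uses. A minimal sketch with a Hugging Face tokenizer is shown below; the model name and output path are placeholders.

```python
# A minimal sketch: render a conversation into pre-formatted text with the
# target model's chat template. The model name and output path are placeholders.
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
]

# tokenize=False returns the rendered template as a plain string,
# matching the "text" field shown above.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

with open("your_preformatted_dataset.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps({"id": "sample-0001", "text": text}, ensure_ascii=False) + "\n")
```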

To use pre-formatted datasets, add the `--is-preformatted` flag to your training command. The `--chat-template` parameter is still required and should match the template used in your pre-formatted text: it is used to identify the user/assistant delimiter tokens, locate the assistant spans, and build the corresponding loss mask.

```bash
torchrun --standalone --nproc_per_node 8 \
scripts/train_eagle3_online.py \
--is-preformatted \
--chat-template qwen \
--train-data-path ./your_preformatted_dataset.jsonl \
# ... other arguments
```
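
To make the loss-mask point above concrete, the sketch below (not the repository's implementation) shows how assistant spans in pre-formatted text can be located from the template's delimiter tokens and turned into a token-level loss mask; the model name and the Qwen-style `<|im_start|>assistant ... <|im_end|>` delimiters are assumptions.

```python
# A rough sketch (not the repository's implementation) of building a token-level
# loss mask from pre-formatted text, assuming Qwen-style delimiters.
import re
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

text = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nHello<|im_end|>\n"
    "<|im_start|>assistant\nHi there!<|im_end|>\n"
)

# Character spans of assistant responses (content between the delimiters).
spans = [
    m.span(1)
    for m in re.finditer(r"<\|im_start\|>assistant\n(.*?)<\|im_end\|>", text, flags=re.DOTALL)
]

# Tokenize with offsets so each token can be mapped back to a character range.
enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)

# loss_mask[i] = 1 only for tokens that fall inside an assistant response.
loss_mask = [
    1 if any(start >= s and end <= e for s, e in spans) else 0
    for start, end in enc["offset_mapping"]
]
```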

Once the `jsonl` file is ready, you can proceed directly to online training, or to hidden-states generation for offline training.

If you have multiple datasets, you can simply merge them into a single jsonl file. For example, you can do something like this
6 changes: 6 additions & 0 deletions scripts/train_eagle3_offline.py
@@ -76,6 +76,11 @@ def parse_args():
parser.add_argument("--draft-attention-backend", type=str, default="flex_attention")
# data processing type
parser.add_argument("--chat-template", type=str, default="llama3")
parser.add_argument(
"--is-preformatted",
action="store_true",
help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
)

# distributed training
parser.add_argument("--tp-size", type=int, default=1)
@@ -247,6 +252,7 @@ def main():
dataset=train_dataset,
tokenizer=tokenizer,
chat_template=args.chat_template,
is_preformatted=args.is_preformatted,
max_length=args.max_length,
cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
cache_key=cache_key,
8 changes: 8 additions & 0 deletions scripts/train_eagle3_online.py
@@ -74,6 +74,11 @@ def parse_args():

# data processing type
parser.add_argument("--chat-template", type=str, default="llama3")
parser.add_argument(
"--is-preformatted",
action="store_true",
help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
)

# distributed training
parser.add_argument("--tp-size", type=int, default=1)
@@ -283,6 +288,7 @@ def main():
cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
cache_key=cache_key,
is_vlm=args.is_vlm,
is_preformatted=args.is_preformatted,
processor=processor,
num_proc=args.build_dataset_num_proc,
)
@@ -317,6 +323,7 @@ def main():
is_vlm=args.is_vlm,
processor=processor,
num_proc=args.build_dataset_num_proc,
is_preformatted=args.is_preformatted,
)
eval_dataloader = prepare_dp_dataloaders(
eval_eagle3_dataset,
@@ -325,6 +332,7 @@
shuffle=False,
process_group=get_dp_group(),
is_vlm=args.is_vlm,
is_preformatted=args.is_preformatted,
)
print_with_rank("Initialized eval dataloader")
