Merged
scripts/prepare_data.py (50 changes: 42 additions & 8 deletions)
@@ -39,6 +39,10 @@ def parse_args():
             "sharegpt",
             "eaglechat",
             "perfectblend",
+            "perfectblend-llama3.1-8b-instruct",
+            "perfectblend-llama3.3-70b-instruct",
+            "perfectblend-llama4-scout-instruct",
+            "perfectblend-llama4-maverick-instruct",
             "magpie-qwen2.5-pro-1m-v0.1",
             "sharegpt4v",
             "allava4v",
@@ -189,20 +193,26 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     total_skipped_count = 0
     with open(train_output_jsonl_path, "w") as f:
         for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"):
-            row, skipped_count = proc_fn(item)
-            if row is None:
-                continue
-            total_skipped_count += skipped_count
+            if proc_fn is not None:
+                row, skipped_count = proc_fn(item)
+                if row is None:
+                    continue
+                total_skipped_count += skipped_count
+            else:
+                row = item
             f.write(json.dumps(row, ensure_ascii=False) + "\n")

     if test_ds is not None:
         test_output_jsonl_path = output_path.joinpath(f"{dataset_name}_test.jsonl")
         with open(test_output_jsonl_path, "w") as f:
             for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"):
-                row, skipped_count = proc_fn(item)
-                if row is None:
-                    continue
-                total_skipped_count += skipped_count
+                if proc_fn is not None:
+                    row, skipped_count = proc_fn(item)
+                    if row is None:
+                        continue
+                    total_skipped_count += skipped_count
+                else:
+                    row = item
                 f.write(json.dumps(row, ensure_ascii=False) + "\n")
Comment on lines 194 to 216
Contributor
medium

There is significant code duplication between the processing loops for the training dataset (lines 195-203) and the test dataset (lines 208-216). This makes the code harder to maintain, as any changes would need to be applied in two places.

To improve this, you could extract the common logic into a helper function. This function would take a dataset, a file handle, and the processing function (proc_fn) as arguments, and would contain the loop and writing logic. This would make process_and_save_ds cleaner and more maintainable.
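
As a rough sketch of that refactor (illustrative only, not part of this PR; the helper name write_split_jsonl and its exact signature are assumptions), the shared loop could look something like this:

import json

from tqdm import tqdm


def write_split_jsonl(ds, f, proc_fn, desc):
    # Process one dataset split and write it to an already-open JSONL file handle.
    # Returns how many rows the processing function reported as skipped.
    skipped = 0
    for item in tqdm(ds, desc=desc):
        if proc_fn is not None:
            row, skipped_count = proc_fn(item)
            if row is None:
                continue
            skipped += skipped_count
        else:
            # No processing function: write the raw item unchanged.
            row = item
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return skipped

Both call sites in process_and_save_ds would then reduce to a single with open(...) block per split that adds the helper's return value to total_skipped_count.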


     if total_skipped_count > 0:
@@ -252,6 +262,30 @@ def main():
ds = load_dataset("mlabonne/open-perfectblend")["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = process_sharegpt_row
elif args.dataset == "perfectblend-llama3.1-8b-instruct":
ds = load_dataset("frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct")[
"train"
]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset == "perfectblend-llama3.3-70b-instruct":
ds = load_dataset(
"frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct"
)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset == "perfectblend-llama4-scout-instruct":
ds = load_dataset(
"frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct"
)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset == "perfectblend-llama4-maverick-instruct":
ds = load_dataset(
"frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct"
)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
Comment on lines +265 to +288
Contributor
medium

The elif chain for handling the new perfectblend-llama* datasets contains a lot of repetitive code. Each block loads a dataset from a specific Hugging Face repository, maps add_index, and sets proc_fn to None.

To make this more maintainable and easier to extend, you could use a dictionary to map the dataset names to their repository IDs. This would consolidate the logic into a single block.

Suggested change
elif args.dataset == "perfectblend-llama3.1-8b-instruct":
ds = load_dataset("frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct")[
"train"
]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset == "perfectblend-llama3.3-70b-instruct":
ds = load_dataset(
"frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct"
)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset == "perfectblend-llama4-scout-instruct":
ds = load_dataset(
"frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct"
)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset == "perfectblend-llama4-maverick-instruct":
ds = load_dataset(
"frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct"
)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None
elif args.dataset in [
"perfectblend-llama3.1-8b-instruct",
"perfectblend-llama3.3-70b-instruct",
"perfectblend-llama4-scout-instruct",
"perfectblend-llama4-maverick-instruct",
]:
dataset_map = {
"perfectblend-llama3.1-8b-instruct": "frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct",
"perfectblend-llama3.3-70b-instruct": "frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct",
"perfectblend-llama4-scout-instruct": "frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct",
"perfectblend-llama4-maverick-instruct": "frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct",
}
repo_id = dataset_map[args.dataset]
ds = load_dataset(repo_id)["train"]
ds = ds.map(add_index, with_indices=True)
proc_fn = None

elif args.dataset == "magpie-qwen2.5-pro-1m-v0.1":
ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"]
ds = ds.rename_column("uuid", "id")