
Commit 06162fc

Update finetune tests to decrease execution time (#1208)
Summary
- Use less data or fewer epochs across the finetune tests
- Decreases finetune test time from about 20 minutes to about 4
Parent: 29ddedb

5 files changed: +21 additions, -25 deletions
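Most of the changes below lean on the percent-slice split syntax from Hugging Face datasets: slices like "train[:5%]" and "train[5%:7%]" resolve to absolute, non-overlapping row ranges, so shrinking a slice directly shrinks how much data each test downloads, tokenizes, and iterates. A minimal sketch of that mechanism, using the wikitext config these tests already use (everything here is standard datasets API):

from datasets import load_dataset

# Percent slices resolve to fixed row ranges of the underlying split, so
# smaller slices mean less data to tokenize and iterate in each test.
train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")
val = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[5%:7%]")
print(len(train), len(val))  # val covers ~2% of the rows, disjoint from train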

tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def test_combined_datasets():
 
 @pytest.mark.unit
 def test_separate_datasets():
-    splits = {"train": "train[:10%]", "validation": "train[10%:20%]"}
+    splits = {"train": "train[:5%]", "validation": "train[5%:7%]"}
     data_args = DatasetArguments(
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )

tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py

Lines changed: 8 additions & 10 deletions
@@ -67,7 +67,7 @@ def test_no_padding_tokenization(self):
         op_manager = TextGenerationDataset.load_from_registry(
             self.data_args.dataset,
             data_args=self.data_args,
-            split="train[5%:10%]",
+            split="train[5%:7%]",
             processor=self.tiny_llama_tokenizer,
         )
         dataset = op_manager.load_dataset()  # load
@@ -82,7 +82,7 @@ def test_no_padding_tokenization(self):
         ex_item = dataset[0]["text"]
         self.assertIn("Below is an instruction that describes a task", ex_item)

-        self.assertEqual(dataset.split, "train[5%:10%]")
+        self.assertEqual(dataset.split, "train[5%:7%]")
         tokenized_dataset = op_manager()
         self.assertIn("input_ids", tokenized_dataset.features)
         self.assertIn("labels", tokenized_dataset.features)
@@ -107,7 +107,7 @@ def test_max_seq_len_clipped(self):
         op_manager = TextGenerationDataset.load_from_registry(
             self.data_args.dataset,
             data_args=self.data_args,
-            split="train[80%:]",
+            split="train[95%:]",
             processor=self.tiny_llama_tokenizer,
         )

@@ -136,15 +136,15 @@ def test_dataset_kwargs_and_percentages(self):
         c4_manager_a = TextGenerationDataset.load_from_registry(
             self.data_args.dataset,
             data_args=self.data_args,
-            split="train[5%:10%]",
+            split="train[5%:6%]",
             processor=self.tiny_llama_tokenizer,
         )
         raw_dataset_a = c4_manager_a.load_dataset()

         c4_manager_b = TextGenerationDataset.load_from_registry(
             self.data_args.dataset,
             data_args=self.data_args,
-            split="train[5%:15%]",
+            split="train[6%:8%]",
             processor=self.tiny_llama_tokenizer,
         )
         raw_dataset_b = c4_manager_b.load_dataset()
@@ -162,7 +162,7 @@ def prepare_fixture(self, tiny_llama_tokenizer):
         [
             ["ptb", "penn_treebank", "train[:5%]", False],
             ["gsm8k", "main", "train[:5%]", True],
-            ["ultrachat_200k", "default", "train_sft[:2%]", False],
+            ["ultrachat_200k", "default", "train_sft[:1%]", False],
         ]
     )
     def test_datasets(self, dataset_key, dataset_config, split, do_concat):
@@ -271,9 +271,7 @@ class TestSplitLoading(unittest.TestCase):
     def prepare_fixture(self, tiny_llama_tokenizer):
         self.tiny_llama_tokenizer = tiny_llama_tokenizer

-    @parameterized.expand(
-        [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]]
-    )
+    @parameterized.expand([["train[95%:]"], [{"train": "train[:5%]"}]])
     def test_split_loading(self, split_def):
         data_args = DatasetArguments(
             dataset="open_platypus",
@@ -302,7 +300,7 @@ class TestTokenizationDataset(unittest.TestCase):
     def prepare_fixture(self, tiny_llama_tokenizer):
         self.tiny_llama_tokenizer = tiny_llama_tokenizer
         dataset = load_dataset("garage-bAInd/Open-Platypus")["train"]
-        self.num_calib_samples = 256
+        self.num_calib_samples = 64
         self.max_seq_len = 512
         self.dataset = dataset.shuffle(seed=42).select(range(self.num_calib_samples))
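The last hunk above also cuts the tokenization fixture from 256 to 64 calibration samples. The shuffle-then-select pattern it uses is standard datasets API; a standalone sketch with the same dataset and seed:

from datasets import load_dataset

# Deterministic shuffle, then keep only the first 64 rows; downstream
# tokenization cost scales linearly with the number of samples kept.
dataset = load_dataset("garage-bAInd/Open-Platypus")["train"]
calib = dataset.shuffle(seed=42).select(range(64))
print(len(calib))  # 64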

tests/llmcompressor/transformers/finetune/finetune_oneshot_configs/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,5 +4,5 @@ model: "Xenova/llama2.c-stories15M"
 dataset: wikitext
 dataset_config_name: "wikitext-2-raw-v1"
 recipe: "tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml"
-num_train_epochs: 1
+num_train_epochs: 0.25
 concat_txt: False

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py

Lines changed: 3 additions & 2 deletions
@@ -19,9 +19,9 @@ class TestOneshotAndFinetune(unittest.TestCase):
     def _test_oneshot_and_finetune(self):
         from llmcompressor.transformers import apply

-        splits = {"train": "train[:30%]", "calibration": "train[30%:40%]"}
+        splits = {"train": "train[:5%]", "calibration": "train[5%:10%]"}
         if self.dataset == "ultrachat-200k":
-            splits = {"train": "train_gen[:30%]", "calibration": "train_gen[30%:40%]"}
+            splits = {"train": "train_gen[:5%]", "calibration": "train_gen[5%:10%]"}

         apply(
             model=self.model,
@@ -30,6 +30,7 @@ def _test_oneshot_and_finetune(self):
             output_dir=self.output,
             recipe=self.recipe,
             num_train_epochs=self.num_train_epochs,
+            num_calibration_samples=64,
             concatenate_data=self.concat_txt,
             splits=splits,
             oneshot_device=self.device,
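Pieced together from the arguments visible in this hunk and the config above, the updated call looks roughly like the sketch below. The values are placeholders taken from elsewhere in these tests, and apply() may accept more parameters than shown; this is not the full signature:

from llmcompressor.transformers import apply

# Sketch assembled only from arguments appearing in this diff.
# num_calibration_samples=64 is the new cap that bounds the oneshot
# calibration pass regardless of how large the calibration split is.
apply(
    model="Xenova/llama2.c-stories15M",  # model named in config.yaml above
    dataset="wikitext",                  # assumption: same dataset as config.yaml
    dataset_config_name="wikitext-2-raw-v1",
    recipe="tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml",
    output_dir="/tmp/oneshot_and_finetune",
    num_train_epochs=0.25,
    num_calibration_samples=64,
    concatenate_data=False,
    splits={"train": "train[:5%]", "calibration": "train[5%:10%]"},
)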

tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py

Lines changed: 8 additions & 11 deletions
@@ -27,7 +27,7 @@ def test_oneshot_sparsification_then_finetune(self):
         concatenate_data = False
         num_calibration_samples = 64
         output_dir = self.output / "oneshot_out"
-        splits = {"calibration": "train[:10%]"}
+        splits = {"calibration": "train[:5%]"}

         with create_session():
             oneshot(
@@ -56,20 +56,18 @@ def test_oneshot_sparsification_then_finetune(self):
         dataset = "open_platypus"
         concatenate_data = False
         output_dir = self.output / "finetune_out"
-        splits = "train[:50%]"
-        max_steps = 25
+        splits = "train[5%:7%]"

         with create_session():
             train(
                 model=model,
                 distill_teacher=distill_teacher,
                 dataset=dataset,
                 output_dir=output_dir,
-                num_calibration_samples=num_calibration_samples,
+                num_train_epochs=0.05,
                 recipe=recipe_str,
                 concatenate_data=concatenate_data,
                 splits=splits,
-                max_steps=max_steps,
             )

         # test reloading checkpoint and final model
@@ -85,11 +83,10 @@ def test_oneshot_sparsification_then_finetune(self):
                 distill_teacher=distill_teacher,
                 dataset=dataset,
                 output_dir=output_dir,
-                num_calibration_samples=num_calibration_samples,
+                num_train_epochs=0.05,
                 recipe=recipe_str,
                 concatenate_data=concatenate_data,
                 splits=splits,
-                max_steps=max_steps,
                 resume_from_checkpoint=True,  # use last checkpoint
             )

@@ -106,7 +103,7 @@ def test_oneshot_quantization_then_finetune(self):
         concatenate_data = False
         num_calibration_samples = 64
         output_dir = self.output / "oneshot_out"
-        splits = {"calibration": "train[:10%]"}
+        splits = {"calibration": "train[:5%]"}

         with create_session():
             oneshot(
@@ -130,17 +127,17 @@ def test_oneshot_quantization_then_finetune(self):
         dataset = "open_platypus"
         concatenate_data = False
         output_dir = self.output / "finetune_out"
-        splits = {"calibration": "train[:10%]", "train": "train[:10%]"}
+        splits = {"calibration": "train[:5%]", "train": "train[5%:7%]"}

         with create_session():
             train(
                 model=model,
                 dataset=dataset,
                 output_dir=output_dir,
-                num_calibration_samples=num_calibration_samples,
                 recipe=recipe,
                 concatenate_data=concatenate_data,
                 splits=splits,
+                num_train_epochs=0.05,
             )

         # test reloading checkpoint and final model
@@ -152,10 +149,10 @@ def test_oneshot_quantization_then_finetune(self):
                 model=model,
                 dataset=dataset,
                 output_dir=output_dir,
-                num_calibration_samples=num_calibration_samples,
                 recipe=recipe,
                 concatenate_data=concatenate_data,
                 splits=splits,
+                num_train_epochs=0.05,
                 resume_from_checkpoint=True,  # use last checkpoint
             )
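A rough back-of-envelope for why the finetune legs above now finish quickly: "train[5%:7%]" keeps about 2% of the rows, and num_train_epochs=0.05 takes only a sliver of a pass over them. The row count and batch size below are assumptions for illustration, not values from the diff:

# Rough arithmetic sketch; the row count and batch size are assumptions.
rows = 25_000                      # approx. Open-Platypus train size (assumed)
subset = int(rows * 0.02)          # "train[5%:7%]" keeps ~2% -> ~500 rows
batch_size = 8                     # hypothetical per-device batch size
steps_per_epoch = subset // batch_size
steps = max(1, int(steps_per_epoch * 0.05))  # num_train_epochs = 0.05
print(subset, steps_per_epoch, steps)        # 500 62 3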
