
Commit a3a9f17

remove stuff, reduce diff
Signed-off-by: Kyle Sayers <[email protected]>
1 parent afabe5a commit a3a9f17

5 files changed: 6 additions, 13 deletions

examples/quantization_w4a16/llama3_example.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from llmcompressor.transformers import oneshot

 # Select model and load it.
-#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

 model = AutoModelForCausalLM.from_pretrained(
@@ -23,7 +23,7 @@
 # Increasing the number of samples can improve accuracy.
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
-BATCH_SIZE = 2
+BATCH_SIZE = 8

 # Load dataset and preprocess.
 ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
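
The hunk above only touches constants. As orientation, here is a minimal, purely illustrative sketch of how a BATCH_SIZE of 8 could group the 512 calibration samples into 64 batches; make_batches and dummy_samples are hypothetical stand-ins, not code from this repository.

# Hypothetical illustration only: chunk the calibration set into batches of 8.
NUM_CALIBRATION_SAMPLES = 512
BATCH_SIZE = 8

def make_batches(samples, batch_size):
    """Yield consecutive slices of `samples` with at most `batch_size` items."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]

dummy_samples = list(range(NUM_CALIBRATION_SAMPLES))  # stand-in for tokenized rows
batches = list(make_batches(dummy_samples, BATCH_SIZE))
assert len(batches) == 64 and all(len(batch) == BATCH_SIZE for batch in batches)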

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 0 additions & 2 deletions
@@ -254,8 +254,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             )
             if isinstance(exception, unfixable_errors):
                 raise exception
-
-            raise exception

             warnings.warn("Falling back to layer_sequential pipeline")
             try:
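
The two deleted lines were an unconditional re-raise sitting between the unfixable-error check and the fallback warning, so removing them lets recoverable errors reach the layer_sequential fallback again. Below is a simplified sketch of that control flow, not the actual GPTQModifier code; unfixable_errors and the pipeline callables are stand-ins.

import warnings

unfixable_errors = (MemoryError,)  # illustrative stand-in for the real tuple

def run_with_fallback(primary, fallback):
    """Try the primary pipeline; re-raise only unfixable errors, else fall back."""
    try:
        return primary()
    except Exception as exception:
        if isinstance(exception, unfixable_errors):
            raise exception
        # With the unconditional `raise exception` removed, execution reaches
        # the fallback path for recoverable errors.
        warnings.warn("Falling back to layer_sequential pipeline")
        return fallback()

def failing_primary():
    raise RuntimeError("sequential pipeline failed")  # recoverable in this sketch

assert run_with_fallback(failing_primary, lambda: "layer_sequential") == "layer_sequential"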

src/llmcompressor/transformers/finetune/data/data_helpers.py

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ def format_calibration_data(
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

     # collate data
-    breakpoint()
     if collate_fn is None:
         tokenizer = getattr(processor, "tokenizer", processor)
         if hasattr(tokenizer, "pad"):
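
The removed line was a stray breakpoint() left in the collation path. For context, the sketch below mirrors the selection logic the surrounding lines imply; the DataCollatorWithPadding / default_collate fallback is an assumption, not read from the file.

from typing import Any, Callable, Optional

from torch.utils.data import default_collate
from transformers import DataCollatorWithPadding

def choose_collate_fn(processor: Any, collate_fn: Optional[Callable] = None) -> Callable:
    """Pick a collate function when the caller did not supply one (sketch)."""
    if collate_fn is not None:
        return collate_fn
    # Multimodal processors wrap a tokenizer; plain tokenizers are used directly.
    tokenizer = getattr(processor, "tokenizer", processor)
    if hasattr(tokenizer, "pad"):
        return DataCollatorWithPadding(tokenizer=tokenizer)  # pad to longest in batch
    return default_collate  # assumption: fall back to torch's default collation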

src/llmcompressor/transformers/finetune/runner.py

Lines changed: 2 additions & 4 deletions
@@ -49,15 +49,14 @@ def __init__(
         data_args: "DataTrainingArguments",
         model_args: "ModelArguments",
         training_args: "TrainingArguments",
-        processor: Processor,
     ):
         self._data_args = data_args
         self._model_args = model_args
         self._training_args = training_args

         self.datasets = {}
         self.trainer = None
-        self.processor = processor
+        self.processor = None
         self.parent_output_dir = self._training_args.output_dir
         self._output_dir = self._training_args.output_dir

@@ -69,7 +68,6 @@ def populate_datasets(self, processor: Processor, add_labels: bool = True):
         :param processor: processor or tokenizer to use for dataset tokenization
         :param add_labels: if True, add labels column to dataset splits
         """
-        # TODO: remove `processor` arg in favor of self.processor
         if self._data_args.dataset is None:
             logger.info(
                 "Running oneshot without calibration data. This is expected for "
@@ -111,7 +109,7 @@ def _get_split_name(inp_str):
                 registry_id,
                 data_args=self._data_args,
                 split=split_str,
-                processor=self.processor,
+                processor=processor,
             )
             tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)
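
Taken together, these hunks restore the earlier contract: the runner is constructed without a processor, self.processor starts as None, and populate_datasets works with the processor argument it receives. A condensed, non-authoritative sketch of that shape (body trimmed to the lines touched here); the matching call site appears in the text_generation.py hunk below.

class StageRunnerSketch:
    """Trimmed illustration of the restored interface; not the real class."""

    def __init__(self, data_args, model_args, training_args):
        self._data_args = data_args
        self._model_args = model_args
        self._training_args = training_args
        self.datasets = {}
        self.trainer = None
        self.processor = None  # no longer injected through __init__

    def populate_datasets(self, processor, add_labels: bool = True):
        # Dataset managers receive the `processor` argument directly rather
        # than reading self.processor.
        ...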

src/llmcompressor/transformers/finetune/text_generation.py

Lines changed: 2 additions & 4 deletions
@@ -394,12 +394,10 @@ def main(

     # Load datasets
     stage_runner = StageRunner(
-        model_args=model_args, data_args=data_args, training_args=training_args, processor=processor
+        model_args=model_args, data_args=data_args, training_args=training_args
     )
     add_labels = training_args.do_train or training_args.run_stages
-    stage_runner.populate_datasets(
-        processor=processor, add_labels=add_labels
-    )
+    stage_runner.populate_datasets(processor=processor, add_labels=add_labels)
     train_dataset = stage_runner.get_dataset_split("train")
     eval_dataset = stage_runner.get_dataset_split("validation")
     calib_dataset = stage_runner.get_dataset_split("calibration")
