
Commit 29bf737

WIP: working
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 44dbf91 commit 29bf737

File tree

8 files changed: +30 additions, -19 deletions


examples/multimodal_vision/gemma3_example.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def data_collator(batch):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=data_collator,
+    # data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.

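For context, the collator being disabled above follows the pattern used across this repo's multimodal examples; the sketch below is a paraphrase of that pattern, not code copied from this commit:

import torch

# Multimodal calibration typically runs one sample per batch; the collator
# just tensorizes each field of that single processed example.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}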
examples/quantization_w4a16/llama3_example.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 from llmcompressor.utils import dispatch_for_generation

 # Select model and load it.
-model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+model_id = "meta-llama/Llama-3.2-3B-Instruct"
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_id)

@@ -16,7 +16,7 @@

 # Select number of samples. 512 samples is a good place to start.
 # Increasing the number of samples can improve accuracy.
-NUM_CALIBRATION_SAMPLES = 12
+NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048

 # Load dataset and preprocess.
@@ -57,10 +57,10 @@ def tokenize(sample):
 oneshot(
     model=model,
     dataset=ds,
-    batch_size=12,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    batch_size=4,
 )

 # Confirm generations of the quantized model look sane.

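One consequence of the restored defaults worth noting: with NUM_CALIBRATION_SAMPLES = 512 and batch_size=4, calibration runs 512 / 4 = 128 batched forward passes rather than 512 single-sample passes.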
src/llmcompressor/args/dataset_arguments.py

Lines changed: 3 additions & 5 deletions
@@ -8,9 +8,7 @@
 """

 from dataclasses import dataclass, field
-from typing import Any, Callable
-
-from transformers import DefaultDataCollator
+from typing import Callable, Optional


 @dataclass
@@ -69,8 +67,8 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )

-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
+    data_collator: Optional[Callable] = field(
+        default=None,
         metadata={"help": "The function to used to form a batch from the dataset"},
     )


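Note the behavioral shift: the field previously always carried a DefaultDataCollator instance, while default=None lets downstream code detect that no collator was supplied and substitute a tokenizer-aware fallback. A minimal sketch of the pattern (names are illustrative, not the library's):

from dataclasses import dataclass, field
from typing import Callable, Optional

@dataclass
class Args:
    data_collator: Optional[Callable] = field(default=None)

def default_collator(features):
    # stand-in for DataCollatorWithPadding(tokenizer)
    return features

args = Args()
collate_fn = args.data_collator or default_collator  # None -> fallback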
src/llmcompressor/datasets/utils.py

Lines changed: 10 additions & 3 deletions
@@ -15,7 +15,7 @@
 from datasets import Dataset
 from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data import DataCollatorWithPadding

 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
@@ -115,22 +115,29 @@ def get_calibration_dataloader(
     )

     calibration_dataset = datasets.get("calibration")
+    tokenizer = getattr(processor, "tokenizer", processor)
+    collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer)
+    if dataset_args.batch_size > 1 and (
+        tokenizer.pad_token is None or tokenizer.pad_token_id < 0
+    ):
+        logger.warning("Could not find padding token. Setting PAD token to EOS token")
+        tokenizer.pad_token = tokenizer.eos_token

     return format_calibration_data(
         tokenized_dataset=calibration_dataset,
+        collate_fn=collate_fn,
         num_calibration_samples=dataset_args.num_calibration_samples,
         batch_size=dataset_args.batch_size,
         do_shuffle=dataset_args.shuffle_calibration_samples,
-        collate_fn=dataset_args.data_collator,
     )


 def format_calibration_data(
     tokenized_dataset: Dataset,
+    collate_fn: Callable,
     num_calibration_samples: int | None = None,
     batch_size: int = 1,
     do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
 ) -> list[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to

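To see what the new fallback collator actually does, here is a standalone demo (the model id is arbitrary; any tokenizer works): DataCollatorWithPadding pads every sequence in a batch to the batch's longest length, which is why a pad token must exist once batch_size > 1.

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same fallback as the diff above

collator = DataCollatorWithPadding(tokenizer)
batch = collator(
    [tokenizer("short prompt"), tokenizer("a somewhat longer calibration prompt")]
)
print(batch["input_ids"].shape)    # (2, length of the longest sequence)
print(batch["attention_mask"][0])  # zeros mark the shorter prompt's padding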
src/llmcompressor/entrypoints/oneshot.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,7 @@
 import os
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Optional

 from loguru import logger
 from torch.utils.data import DataLoader
@@ -249,6 +249,8 @@ def oneshot(
     dataset_config_name: str | None = None,
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
+    batch_size: int = 1,
+    data_collator: Optional[Callable] = None,
     num_calibration_samples: int = 512,
     shuffle_calibration_samples: bool = True,
     max_seq_length: int = 384,

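With the signature extended, both new knobs are available from the public entrypoint. A hedged usage sketch (model, ds, and recipe are placeholders, not values from this commit):

from llmcompressor import oneshot

oneshot(
    model=model,                 # a loaded PreTrainedModel
    dataset=ds,                  # a preprocessed calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    batch_size=4,                # new: batches calibration samples
    data_collator=None,          # new: None falls back to DataCollatorWithPadding
)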
src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -10,7 +10,12 @@
 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
-from llmcompressor.utils import calibration_forward_context, dispatch_for_generation, targets_lm_head, disable_lm_head
+from llmcompressor.utils import (
+    calibration_forward_context,
+    disable_lm_head,
+    dispatch_for_generation,
+    targets_lm_head,
+)

 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -19,8 +19,8 @@
     DISABLE_QAC_MODIFIERS,
     DisableQuantization,
     calibration_forward_context,
-    targets_lm_head,
     disable_lm_head,
+    targets_lm_head,
 )

 if TYPE_CHECKING:

src/llmcompressor/utils/helpers.py

Lines changed: 3 additions & 4 deletions
@@ -18,7 +18,7 @@
 from collections import OrderedDict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union
 from urllib.parse import urlparse

 import numpy
@@ -1094,16 +1094,15 @@ def disable_lm_head(model: torch.nn.Module):


 def targets_lm_head(model: PreTrainedModel, modifiers: list["Modifier"]) -> bool:
-    """ Returns True if the given modifiers target the lm_head """
+    """Returns True if the given modifiers target the lm_head"""
     from llmcompressor.transformers.compression.compressed_tensors_utils import (
-        targets_embeddings
+        targets_embeddings,
     )

     targets = sum(
         (list(modifier.get_targets(model)) for modifier in modifiers), start=[]
     )
     return targets_embeddings(model, targets, check_input=True, check_output=False)
-


 @contextlib.contextmanager

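One idiom above worth calling out: sum(..., start=[]) flattens a list of lists into a single list, for example:

parts = [["a", "b"], ["c"], []]
flat = sum(parts, start=[])  # -> ["a", "b", "c"]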