Skip to content

Commit f87a78f

Browse files
committed
add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 447b5ad commit f87a78f

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

examples/quantization_w4a16/llama3_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from llmcompressor.transformers import oneshot
66

77
# Select model and load it.
8-
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
9-
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
8+
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
109

1110
model = AutoModelForCausalLM.from_pretrained(
1211
MODEL_ID,

src/llmcompressor/transformers/finetune/data/data_helpers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import logging
21
import os
3-
import warnings
42
from typing import Any, Callable, Dict, List, Optional
53

64
import torch
75
from datasets import Dataset, load_dataset
6+
from loguru import logger
87
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
98
from transformers.data.data_collator import (
109
DataCollatorWithPadding,
@@ -13,7 +12,6 @@
1312

1413
from llmcompressor.typing import Processor
1514

16-
LOGGER = logging.getLogger(__name__)
1715
LABELS_MASK_VALUE = -100
1816

1917
__all__ = [
@@ -56,7 +54,7 @@ def format_calibration_data(
5654
if num_calibration_samples is not None:
5755
safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
5856
if safe_calibration_samples != num_calibration_samples:
59-
LOGGER.warn(
57+
logger.warning(
6058
f"Requested {num_calibration_samples} calibration samples but "
6159
f"the provided dataset only has {safe_calibration_samples}. "
6260
)
@@ -68,7 +66,7 @@ def format_calibration_data(
6866
if hasattr(tokenizer, "pad"):
6967
collate_fn = DataCollatorWithPadding(tokenizer)
7068
else:
71-
warnings.warn(
69+
logger.warning(
7270
"Could not find processor, attempting to collate without padding "
7371
"(may fail for batch_size > 1)"
7472
)

tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import pytest
33
import torch
44
from datasets import Dataset
5+
from transformers import AutoTokenizer
56

67
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
78
from llmcompressor.transformers.finetune.data.data_helpers import (
89
format_calibration_data,
910
get_raw_dataset,
1011
make_dataset_splits,
1112
)
13+
from llmcompressor.transformers.finetune.text_generation import configure_processor
1214

1315

1416
@pytest.mark.unit
@@ -60,15 +62,42 @@ def test_separate_datasets():
6062

6163

6264
@pytest.mark.unit
63-
def test_format_calibration_data():
64-
tokenized_dataset = Dataset.from_dict(
65-
{"input_ids": torch.randint(0, 512, (8, 2048))}
65+
def test_format_calibration_data_padded_tokenized():
66+
vocab_size = 512
67+
seq_len = 2048
68+
ds_size = 16
69+
padded_tokenized_dataset = Dataset.from_dict(
70+
{"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))}
6671
)
6772

6873
calibration_dataloader = format_calibration_data(
69-
tokenized_dataset, num_calibration_samples=4, batch_size=2
74+
padded_tokenized_dataset, num_calibration_samples=8, batch_size=4
7075
)
7176

7277
batch = next(iter(calibration_dataloader))
78+
assert batch["input_ids"].size(0) == 4
7379

80+
81+
@pytest.mark.unit
82+
def test_format_calibration_data_unpadded_tokenized():
83+
vocab_size = 512
84+
ds_size = 16
85+
unpadded_tokenized_dataset = Dataset.from_dict(
86+
{
87+
"input_ids": [
88+
torch.randint(0, vocab_size, (seq_len,)) for seq_len in range(ds_size)
89+
]
90+
}
91+
)
92+
processor = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
93+
configure_processor(processor)
94+
95+
calibration_dataloader = format_calibration_data(
96+
unpadded_tokenized_dataset,
97+
num_calibration_samples=8,
98+
batch_size=4,
99+
processor=processor,
100+
)
101+
102+
batch = next(iter(calibration_dataloader))
74103
assert batch["input_ids"].size(0) == 4

0 commit comments

Comments
 (0)