|
2 | 2 | from datasets import load_dataset |
3 | 3 | from llmcompressor import oneshot |
4 | 4 | from llmcompressor.modifiers.quantization import QuantizationModifier |
| 5 | +import torch |
5 | 6 |
|
# Local snapshot path of the Qwen3.5 MoE checkpoint to quantize.
MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27"
|
|
25 | 26 | ], |
26 | 27 | ) |
27 | 28 |
|
# Calibration dataset configuration: the neuralmagic/calibration set
# ("LLM" config) is used for calibration sampling.
DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 8192

# Load only the first NUM_CALIBRATION_SAMPLES rows of the train split.
ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
|
39 | 35 |
|
def preprocess_function(example):
    """Convert one calibration example into tokenized chat-template inputs.

    Each message's plain-text content is wrapped in the structured
    ``[{"type": "text", "text": ...}]`` form that a multimodal chat
    template expects, then the whole conversation is tokenized in one
    pass via the processor's chat template.

    Args:
        example: Dataset row containing a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The dict of tokenized inputs produced by
        ``processor.apply_chat_template`` (``return_dict=True``).
    """
    # Fixed typo in local name: "messgages" -> "messages".
    messages = [
        {
            "role": message["role"],
            "content": [{"type": "text", "text": message["content"]}],
        }
        for message in example["messages"]
    ]

    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        # Calibration uses complete conversations; no generation prompt.
        add_generation_prompt=False,
    )
61 | 57 |
|
62 | 58 |
|
# Map one example at a time; drop the raw columns so only the
# tokenized fields reach calibration.
ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)
| 61 | + |
def data_collator(batch):
    """Collate a single-example batch into model-ready tensors.

    Args:
        batch: List containing exactly one mapped example dict.

    Returns:
        Dict mapping each field to a tensor. ``pixel_values`` is cast to
        bfloat16 with its leading batch dim squeezed; every other field
        becomes a plain tensor.

    Raises:
        ValueError: If the batch does not hold exactly one example.
    """
    # Calibration feeds one sample at a time; anything else is a bug.
    # Raise explicitly instead of `assert`, which is stripped under -O.
    if len(batch) != 1:
        raise ValueError(f"expected a batch of exactly 1 example, got {len(batch)}")
    return {
        key: (
            torch.tensor(value)
            if key != "pixel_values"
            else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        )
        for key, value in batch[0].items()
    }
64 | 73 |
|
65 | 74 |
|
# Apply quantization: run oneshot calibration over the prepared dataset
# with the custom collator, calibrating every MoE expert.
oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    data_collator=data_collator,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    moe_calibrate_all_experts=True,
)
|
0 commit comments