Commit 4fe45de

add deepseek example (#171)
1 parent 7358094 commit 4fe45de

File tree

2 files changed: +130 −0 lines changed
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# Select a Mixture of Experts model for quantization.
MODEL_ID = "deepseek-ai/DeepSeek-V2.5"

# Adjust based on the number of desired GPUs.
# If not enough GPU memory is available, some layers will automatically be offloaded to the CPU.
device_map = calculate_offload_device_map(
    MODEL_ID,
    reserve_for_hessians=True,
    num_gpus=2,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Define an llmcompressor recipe for W4A16 quantization.
# Since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision.
recipe = "deepseek_recipe_w4a16.yaml"

SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16"


oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    save_compressed=True,
    output_dir=SAVE_DIR,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")


# Run the model with vLLM, if it is installed.
try:
    from vllm import LLM, SamplingParams

    vllm_installed = True
except ImportError:
    vllm_installed = False

if vllm_installed:
    print("vLLM installed, running using vLLM")
    sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
    llm = LLM(
        model=SAVE_DIR,
        tensor_parallel_size=2,
        trust_remote_code=True,
        max_model_len=1042,
        dtype=torch.half,
    )
    prompts = [
        "The capital of France is",
        "The president of the US is",
        "My name is",
    ]

    outputs = llm.generate(prompts, sampling_params)
    print("================= vLLM GENERATION ======================")
    for output in outputs:
        assert output
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print("PROMPT", prompt)
        print("GENERATED TEXT", generated_text)
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
quant_stage:
  quant_modifiers:
    GPTQModifier:
      sequential_update: true
      ignore: [lm_head, "re:.*mlp.gate$"]
      config_groups:
        group_0:
          weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false}
          targets: [Linear]
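
This is presumably the deepseek_recipe_w4a16.yaml file the script loads. The script imports GPTQModifier but defers the recipe to this YAML; for reference, a sketch of the same recipe expressed directly in Python, assuming GPTQModifier's constructor accepts the same fields as the YAML keys:

# Sketch: the YAML recipe above, expressed with the GPTQModifier the script imports.
# Field names mirror the YAML keys one-to-one.
recipe = GPTQModifier(
    sequential_update=True,
    ignore=["lm_head", "re:.*mlp.gate$"],
    config_groups={
        "group_0": {
            "weights": {
                "num_bits": 4,
                "type": "int",
                "symmetric": True,
                "strategy": "channel",
                "dynamic": False,
            },
            "targets": ["Linear"],
        }
    },
)

oneshot's recipe argument takes either a recipe file path, as in the script, or modifier objects like this one, so the two forms should be interchangeable.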
