Skip to content

Commit 3c216dd

Browse files
DummyModel script
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 5bd51df commit 3c216dd

File tree

4 files changed

+154
-27
lines changed

4 files changed

+154
-27
lines changed

examples/transform/llama3_example.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from llmcompressor.utils import dispatch_for_generation
88

99
# Select model and load it.
10-
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" # "meta-llama/Meta-Llama-3-8B-Instruct"
10+
# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
11+
# MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" # TODO hidden size 3072 causes failure when creating hadamard
12+
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
1113

1214
model = AutoModelForCausalLM.from_pretrained(
1315
MODEL_ID,
@@ -62,17 +64,18 @@ def tokenize(sample):
6264
# preset_config="QUIP" output sensible, but cannot load saved
6365
# checkpoint or run evals (~4hrs to run)
6466
TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"),
65-
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
67+
# QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
6668
]
6769

6870
# Apply algorithms.
6971
oneshot(
7072
model=model,
71-
dataset=ds,
7273
recipe=recipe,
73-
pipeline="sequential",
74-
max_seq_length=MAX_SEQUENCE_LENGTH,
75-
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
74+
# dataset=ds,
75+
pipeline="datafree",
76+
# max_seq_length=MAX_SEQUENCE_LENGTH,
77+
# num_calibration_samples=NUM_CALIBRATION_SAMPLES,
78+
log_dir=None,
7679
)
7780

7881
# # Confirm generations of the quantized model look sane.
@@ -84,7 +87,7 @@ def tokenize(sample):
8487
print(tokenizer.decode(output[0]))
8588
# print("==========================================\n\n")
8689

87-
# Save to disk compressed.
88-
SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16"
89-
model.save_pretrained(SAVE_DIR, save_compressed=True)
90-
tokenizer.save_pretrained(SAVE_DIR)
90+
# # Save to disk compressed.
91+
# SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16"
92+
# model.save_pretrained(SAVE_DIR, save_compressed=True)
93+
# tokenizer.save_pretrained(SAVE_DIR)

examples/transform/spinquant_dummy.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from datasets import load_dataset
2+
from transformers import AutoModelForCausalLM, AutoTokenizer
3+
import torch
4+
from compressed_tensors.utils import update_parameter_data
5+
from llmcompressor import oneshot
6+
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
7+
from llmcompressor.modifiers.transform import TransformModifier
8+
from llmcompressor.utils import dispatch_for_generation
9+
from transformers.models.llama.modeling_llama import (
10+
LlamaRMSNorm,
11+
)
12+
13+
hidden_dim = intermediate_dim = 64
14+
up_dim = 128
15+
num_embeddings = 12
16+
17+
18+
# TODO remove file before merging
19+
20+
21+
class DummySelfAttn(torch.nn.Module):
22+
def __init__(self, hidden_dim, intermediate_dim):
23+
super().__init__()
24+
self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None)
25+
self.k_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None)
26+
self.v_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None)
27+
self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None)
28+
self.num_heads = 1
29+
self.num_key_value_groups = 1
30+
31+
def forward(self, hidden_states):
32+
q = self.q_proj(hidden_states)
33+
k = self.k_proj(hidden_states)
34+
v = self.v_proj(hidden_states)
35+
36+
### EAGER ATTENTION
37+
attn_weights = torch.matmul(q.T, k)
38+
39+
attn_weights = torch.nn.functional.softmax(
40+
attn_weights, dim=-1, dtype=torch.float32
41+
).to(q.dtype)
42+
attn_output = torch.matmul(attn_weights, v.T)
43+
attn_output = attn_output.T.contiguous()
44+
45+
return self.o_proj(attn_output)
46+
47+
48+
class DummyMLP(torch.nn.Module):
49+
def __init__(self, hidden_dim, up_dim):
50+
super().__init__()
51+
self.up_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None)
52+
self.gate_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None)
53+
self.down_proj = torch.nn.Linear(up_dim, hidden_dim, bias=None)
54+
self.act_fn = torch.nn.SiLU()
55+
56+
def forward(self, x):
57+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
58+
59+
60+
class DummyModel(torch.nn.Module):
61+
def __init__(self, num_embeddings, hidden_dim, intermediate_dim, up_dim):
62+
super().__init__()
63+
self.embed_tokens = torch.nn.Embedding(num_embeddings, hidden_dim)
64+
self.input_layernorm = LlamaRMSNorm(hidden_dim)
65+
self.post_attention_layernorm = LlamaRMSNorm(hidden_dim)
66+
self.self_attn = DummySelfAttn(hidden_dim, intermediate_dim)
67+
self.mlp = DummyMLP(hidden_dim, up_dim)
68+
self.lm_head = torch.nn.Linear(hidden_dim, num_embeddings, bias=None)
69+
70+
def forward(self, input_ids):
71+
x = self.embed_tokens(input_ids)
72+
x = self.input_layernorm(x)
73+
x = self.self_attn(x)
74+
x = self.post_attention_layernorm(x)
75+
x = self.mlp(x)
76+
return self.lm_head(x)
77+
78+
79+
model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim)
80+
81+
# TODO Uncomment this to see norm diff > 1e-6
82+
# This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2
83+
# Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant)
84+
# https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39
85+
# update_parameter_data(
86+
# model.input_layernorm,
87+
# torch.rand(model.input_layernorm.weight.shape),
88+
# "weight",
89+
# )
90+
91+
input_ids = torch.IntTensor([1, 2, 3, 4, 5])
92+
orig_output = model(input_ids)
93+
94+
recipe = [
95+
# NOTE: preset_config="QUIP" output sensible, but cannot load saved
96+
# checkpoint or run evals (~4hrs to run)
97+
TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"),
98+
# QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
99+
]
100+
101+
oneshot(
102+
model=model,
103+
recipe=recipe,
104+
pipeline="datafree",
105+
log_dir=None,
106+
)
107+
108+
# # Confirm generations of the quantized model look the same
109+
transformed_output = model(input_ids)
110+
111+
print(f"Norm Diff {(orig_output-transformed_output).norm()}")
112+
print(f"Norm {orig_output.norm()}, {transformed_output.norm()}")

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def __init__(
125125
self.output_dir = output_dir
126126

127127
# initialize the model and processor
128-
pre_process(model_args)
128+
# TODO Remove Comment before merge, this is just needed for DummyModel
129+
# pre_process(model_args)
129130

130131
# Set instance attributes
131132
self.model = self.model_args.model

src/llmcompressor/modifiers/transform/presets/spinquant.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,54 @@
99
type="hadamard",
1010
apply=[
1111
TransformArgs(
12-
targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"],
12+
targets=[
13+
# outermost rotation
14+
"re:.*embed_tokens$",
15+
# attention rotations
16+
"re:.*o_proj$",
17+
# mlp rotations
18+
"re:.*down_proj$",
19+
],
1320
location="weight_output",
1421
),
1522
TransformArgs(
1623
targets=[
24+
# outermost rotation
25+
"lm_head",
26+
# attention rotations
1727
"re:.*q_proj$",
1828
"re:.*k_proj$",
1929
"re:.*v_proj$",
30+
# mlp rotations
2031
"re:.*up_proj$",
2132
"re:.*gate_proj$",
22-
"lm_head",
2333
],
2434
location="weight_input",
2535
inverse=True,
2636
),
2737
],
2838
),
29-
"R2": TransformScheme(
30-
type="hadamard",
31-
apply=[
32-
TransformArgs(
33-
targets=["re:.*v_proj$"],
34-
location="weight_output",
35-
),
36-
TransformArgs(
37-
targets=["re:.*o_proj$"], location="weight_input", inverse=True
38-
),
39-
],
40-
),
39+
# "R2": TransformScheme(
40+
# type="hadamard",
41+
# # TODO infer head_dim from config.json in SpinQuantModifier
42+
# head_dim=128,
43+
# apply=[
44+
# TransformArgs(targets=["re:.*v_proj$"], location="weight_output"),
45+
# TransformArgs(
46+
# targets=["re:.*o_proj$"],
47+
# location="weight_input",
48+
# inverse=True,
49+
# ),
50+
# ],
51+
# ),
4152
}
4253
)
4354

4455
# All rotations
4556
LLAMA_SPINQUANT = TransformConfig(
4657
config_groups={
47-
"R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"],
48-
"R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"],
58+
# "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"],
59+
# "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"],
4960
"R3": TransformScheme(
5061
type="hadamard",
5162
apply=[

0 commit comments

Comments
 (0)