Skip to content

Commit 4a58bb1

Browse files
committed
clean-up
1 parent a7bf319 commit 4a58bb1

File tree

2 files changed

+97
-45
lines changed

2 files changed

+97
-45
lines changed

examples/weight_transform.py

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,62 @@
1+
import torch
2+
from compressed_tensors.quantization import (
3+
QuantizationArgs,
4+
QuantizationScheme,
5+
QuantizationStrategy,
6+
)
17
from compressed_tensors.transforms import Hadamard, RandomHadamard, Transforms
28
from compressed_tensors.transforms.transform_args import (
39
ModuleTarget,
410
TransformationArgs,
511
)
612
from compressed_tensors.transforms.transform_config import TransformationConfig
7-
from compressed_tensors.transforms.transform_data import TransformData
813
from compressed_tensors.transforms.transform_scheme import TransformationScheme
914
from transformers import AutoModelForCausalLM, AutoTokenizer
10-
import torch
1115

12-
ignore = ["re:*.mlp.down_proj$"]
13-
module_targets = [ModuleTarget.WEIGHTS]
16+
from llmcompressor import oneshot
17+
from llmcompressor.modifiers.quantization import QuantizationModifier
18+
19+
# U(W)V.T
20+
21+
ignore = ["re:.*.mlp.down_proj$"]
22+
module_targets = [ModuleTarget.WEIGHT.value]
1423

15-
# Start with a processed
16-
targets = ["Linear"] # 2048 * 2048
24+
# Start with a processed
25+
targets = ["Linear"] # 2048 * 2048
1726
v_linear_args = TransformationArgs(
18-
targets=targets, module_targets=module_targets, ignore=ignore, call_args={"transpose": True, "first": False}
27+
targets=targets,
28+
module_targets=module_targets,
29+
ignore=ignore,
30+
call_args={"transpose": True, "first": False},
1931
)
2032

21-
targets = ["re:*.mlp.down_proj$"] # 5632 * 5632
33+
targets = ["re:.*.mlp.down_proj$"] # 8192 * 8192
2234
v_down_proj = TransformationArgs(
23-
targets=targets, module_targets=module_targets, call_args={"transpose": True, "first": False}
35+
targets=targets,
36+
module_targets=module_targets,
37+
call_args={"transpose": True, "first": False},
2438
)
2539

26-
targets = ["re:*.attn.q_proj$", "re:*.attn.o_proj$", "re:*.mlp.down_proj$"] # 2048 * 2048
40+
targets = [
41+
"re:.*.attn.q_proj$",
42+
"re:.*.attn.o_proj$",
43+
"re:.*.mlp.down_proj$",
44+
] # 2048 * 2048
2745
u_q_o_down_proj = TransformationArgs(
28-
targets=targets, module_targets=module_targets,
46+
targets=targets,
47+
module_targets=module_targets,
2948
)
3049

31-
targets = ["re:*.attn.gate_proj$", "re:*.mlp.up_proj$"] # 5632 * 5632
50+
targets = ["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"] # 8192 * 8192
3251
u_gate_up_proj = TransformationArgs(
33-
targets=targets, module_targets=module_targets,
52+
targets=targets,
53+
module_targets=module_targets,
3454
)
3555

36-
targets = ["re:*.attn.k_proj$", "re:*.attn.v_proj$"] # 256 * 256
56+
targets = ["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"] # 512 * 512
3757
u_k_v_proj = TransformationArgs(
38-
targets=targets, module_targets=module_targets,
58+
targets=targets,
59+
module_targets=module_targets,
3960
)
4061

4162

@@ -51,7 +72,7 @@
5172
v_scheme_down_proj = TransformationScheme(
5273
transform_type="random-hadamard",
5374
groups=[v_down_proj],
54-
transform_creation_args={"size": 5632},
75+
transform_creation_args={"size": 8192},
5576
)
5677

5778
# We could combine multiple args to the same scheme but then would make it more difficult to consolidate order of transforms
@@ -64,35 +85,65 @@
6485
u_scheme_gate_up_proj = TransformationScheme(
6586
transform_type="random-hadamard",
6687
groups=[u_gate_up_proj],
67-
transform_creation_args={"size": 5632},
88+
transform_creation_args={"size": 8192},
6889
)
6990

7091
u_scheme_k_v_proj = TransformationScheme(
7192
transform_type="random-hadamard",
7293
groups=[u_k_v_proj],
73-
transform_creation_args={"size": 256},
94+
transform_creation_args={"size": 512},
7495
)
7596

7697
# QuIP Recipe with weight only quantization
7798
config = TransformationConfig(
7899
transform_groups={
79100
"u_transform_q_o_down_proj": u_scheme_q_o_down_proj,
80-
"u_transform_gate_up_proj": u_scheme_gate_up_proj,
81101
"u_transform_k_v_proj": u_scheme_k_v_proj,
102+
"u_transform_gate_up_proj": u_scheme_gate_up_proj,
82103
"v_transform_linear": v_scheme,
83-
"v_transform_down_proj": v_scheme_down_proj
104+
"v_transform_down_proj": v_scheme_down_proj,
84105
}
85106
)
86107

87-
#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
88-
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
108+
recipe = QuantizationModifier(
109+
targets="Linear",
110+
ignore=["lm_head"],
111+
config_groups={
112+
"group_0": QuantizationScheme(
113+
targets=["Linear"],
114+
weights=QuantizationArgs(
115+
num_bits=4,
116+
symmetric=True,
117+
strategy=QuantizationStrategy.GROUP,
118+
group_size=128,
119+
),
120+
)
121+
},
122+
transforms_config=config,
123+
)
124+
125+
MODEL_ID = "meta-llama/Llama-3.2-1B"
89126

90127
model = AutoModelForCausalLM.from_pretrained(
91-
MODEL_ID,
92-
device_map="auto",
93-
torch_dtype="auto",
128+
MODEL_ID, device_map="auto", torch_dtype="auto"
94129
)
95130
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
131+
132+
oneshot(model=model, recipe=recipe)
133+
134+
print("\n\n")
135+
print("========== SAMPLE GENERATION ==============")
136+
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
137+
output = model.generate(input_ids, max_new_tokens=100)
138+
print(tokenizer.decode(output[0]))
139+
print("==========================================\n\n")
140+
141+
# Save to disk compressed.
142+
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-Transforms"
143+
model.save_pretrained(SAVE_DIR)
144+
tokenizer.save_pretrained(SAVE_DIR)
145+
146+
"""
96147
x = model.model.layers[0]
97148
attn = x.self_attn
98149
mlp = x.mlp
@@ -104,16 +155,26 @@
104155
attn.o_proj,
105156
mlp.gate_proj,
106157
mlp.down_proj,
107-
mlp.up_proj
158+
mlp.up_proj,
108159
]
109160
110-
for layer in layers:
161+
from compressed_tensors.transforms.hadamard_utils import (
162+
deterministic_hadamard_matrix,
163+
random_hadamard_matrix,
164+
)
111165
166+
for layer in layers:
112167
current_weight = layer.weight
168+
original_weight = current_weight.data.clone()
113169
(n, m) = current_weight.shape
114-
U = torch.eye(n).to("cuda").to(torch.bfloat16)
115-
V = torch.eye(m).to("cuda").to(torch.bfloat16)
116-
print(n, layer)
170+
171+
U = torch.Tensor(random_hadamard_matrix(n)).to("cuda").to(torch.float32)
172+
V = torch.Tensor(random_hadamard_matrix(m)).to("cuda").to(torch.float32)
117173
118174
output = torch.matmul(U, current_weight)
119175
output = torch.matmul(output, V.T)
176+
177+
# apply untransform
178+
x = torch.matmul(U.T, torch.matmul(output, V))
179+
print(torch.max(abs(x - original_weight)))
180+
"""

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
55
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
66
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
7+
from compressed_tensors.transforms.apply import apply_transforms_to_parameter
78
from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
89
from loguru import logger
910
from torch.nn import Module
@@ -123,25 +124,15 @@ def update_weight_zp_scale(module: Module):
123124

124125
transform_data = getattr(module, "transform_data", None)
125126
if transform_data is not None:
126-
# order that the transforms were added to match the order they should be applied
127127
untransformed_weight = module.weight.data.clone()
128-
for transform_name, transform_values in transform_data.data.items():
129-
transform = getattr(module, transform_name)
130-
apply = transform_values.get("apply")
131-
call_args = transform_values.get("call_args")
132-
if call_args:
133-
transformed_weight = apply(
134-
input_tensor=module.weight, transform=transform, **call_args
135-
)
136-
else:
137-
transformed_weight = apply(
138-
input_tensor=module.weight, transform=transform
139-
)
140-
module.weight.data.copy_(transformed_weight)
128+
apply_transforms_to_parameter(
129+
module=module,
130+
module_parameter=module.weight,
131+
transform_data=transform_data,
132+
)
141133

142134
call_observer(module=module, base_name="weight")
143135

144-
# TODO: what do we do here?
145136
if transform_data is not None:
146137
module.weight.data.copy_(untransformed_weight)
147138

0 commit comments

Comments
 (0)