
Commit a9b2f51

R3 and R4 rotations work, but not with the SDPA attention implementation.

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5aa3586

File tree: 4 files changed, +60 −46 lines

examples/transform/spinquant_example.py
Lines changed: 11 additions & 6 deletions

@@ -13,7 +13,7 @@
     MODEL_ID,
     torch_dtype="auto",
 )
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, attn_implementation="eager")

 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
@@ -58,8 +58,10 @@ def tokenize(sample):
 # * apply spinquant transforms to model in order to make quantization easier
 # * quantize the weights to 4 bit with GPTQ with a group size 128
 recipe = [
-    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
-    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
+    SpinQuantModifier(
+        rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
+    ),
+    # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]

 # Apply algorithms.
@@ -75,9 +77,12 @@ def tokenize(sample):
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
-print(tokenizer.decode(output[0]))
+from llmcompressor.utils import calibration_forward_context
+
+with calibration_forward_context(model):
+    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+    output = model.generate(input_ids, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
 print("==========================================\n\n")

 # Save to disk compressed.
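For orientation, the sketch below shows roughly how the changed pieces of this example fit together end to end. It is a minimal sketch, not the example file itself: the model id is a placeholder, the import paths for oneshot, SpinQuantModifier, and dispatch_for_generation are assumed from this repository's layout, and attn_implementation="eager" is passed on the model load, since in transformers that kwarg is consumed by the model rather than the tokenizer.

# Minimal sketch (assumptions noted above); not the example file itself.
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import calibration_forward_context
from llmcompressor.utils.dev import dispatch_for_generation

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model id

# Eager attention (per the commit message, SDPA does not work with R3/R4 yet).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype="auto", attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# All four rotations; weight quantization is commented out in this commit.
recipe = [
    SpinQuantModifier(
        rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
    )
]
oneshot(model=model, recipe=recipe)  # calibration dataset arguments omitted for brevity

# R3/R4 are applied online at runtime, so generation runs under the calibration
# forward context, matching the change to the example above.
dispatch_for_generation(model)
with calibration_forward_context(model):
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))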

src/llmcompressor/modifiers/quantization/calibration.py
Lines changed: 3 additions & 1 deletion

@@ -21,6 +21,8 @@
 from llmcompressor.utils.helpers import getattr_chain

 if TYPE_CHECKING:
+    from compressed_tensors.modeling.attention import CompressedAttentionImpl
+
     from llmcompressor.modifiers.utils.hooks import HooksMixin


@@ -213,7 +215,7 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):


 def register_calibrate_attn_hooks(
-    modifier: HooksMixin, attention_impl
+    modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl"
 ) -> Set[RemovableHandle]:
     return {
         modifier.register_hook(
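The new annotations lean on the standard TYPE_CHECKING pattern: the compressed_tensors import only runs under a type checker, so the annotations are written as strings (forward references) and add no runtime import cost. A generic, self-contained sketch of that pattern (the imported module below is hypothetical, not an llmcompressor API):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright), never at runtime,
    # so an expensive or circular import stays out of the runtime import graph.
    from heavy_package import HeavyType  # hypothetical module

def describe(value: "HeavyType") -> str:
    # Quoted annotation = forward reference; the name need not exist at runtime.
    return f"got {type(value).__name__}"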

src/llmcompressor/modifiers/quantization/quantization/mixin.py
Lines changed: 5 additions & 4 deletions

@@ -242,10 +242,6 @@ def _initialize_observers(self, module: torch.nn.Module):
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()

-        # TODO: attnq
-        # attention_impl = get_compressed_attention_impl()
-        # hooks |= register_calibrate_attn_hooks(self, attention_impl)
-
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue
@@ -264,6 +260,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                    self.register_hook(module, calibrate_input_hook, "forward_pre")
                )

+            # TODO: attnq
+            # if is_attention:
+            #     attention_impl = CompressedAttentionImpl.from_module(module)
+            #     hooks |= register_calibrate_attn_hooks(self, attention_impl)
+
             # kv_cache activations. Within `apply_quantization_config`, the config is
             # modified to use attention output quantization if a kv_cache_scheme exists
             if is_attention and output:
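The surrounding method accumulates torch RemovableHandle objects so every calibration hook can be torn down after the oneshot run. Below is a self-contained sketch of that hook lifecycle in plain torch, not the HooksMixin API used in this file:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 4))
hooks = set()

def calibrate_input_hook(module, args):
    # forward_pre hook: inspect inputs while calibration data flows through
    print(f"{module.__class__.__name__} input shape: {tuple(args[0].shape)}")

for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        hooks.add(module.register_forward_pre_hook(calibrate_input_hook))

model(torch.randn(2, 4))   # hooks fire during the calibration forward pass
for handle in hooks:       # remove them once calibration is finished
    handle.remove()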

src/llmcompressor/modifiers/transform/spinquant/base.py
Lines changed: 41 additions & 35 deletions

@@ -109,7 +109,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             config_groups["R2"] = self._create_r2_scheme(state.model)

         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(state.model)

         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -214,41 +214,47 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )

-    def _create_r3_scheme(self) -> TransformScheme:
-        return (
-            TransformScheme(
-                type=self.transform_type,
-                randomize=self.randomize,
-                requires_grad=self.learnable,
-                apply=[
-                    TransformArgs(
-                        targets=[self.mappings.attn],
-                        location="attn_q",
-                    ),
-                    TransformArgs(
-                        targets=[self.mappings.attn],
-                        location="attn_k",
-                    ),
-                ],
-            ),
+    def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            head_dim = config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
+
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="attn_q",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="attn_k",
+                ),
+            ],
         )

     def _create_r4_scheme(self) -> TransformScheme:
-        return (
-            TransformScheme(
-                type=self.transform_type,
-                randomize=self.randomize,
-                requires_grad=self.learnable,
-                apply=[
-                    TransformArgs(
-                        targets=[*self.mappings.mlp_out],
-                        location="input",
-                    ),
-                    TransformArgs(
-                        targets=[*self.mappings.mlp_out],
-                        location="weight_input",
-                        inverse=True,
-                    ),
-                ],
-            ),
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            apply=[
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="input",
+                ),
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
         )