Skip to content

Commit d6eb2be

Browse files
YatimaiHD and HDCharles authored
fix: handle packed weights in granite4 to_3d_expert (W4A16 support) (#2425)
SUMMARY: Fix the W4A16 shape mismatch in to_3d_expert() reported in #2338 (first error). The original code hardcoded shapes for FP8 quantization only. The fix calculates all shapes up front (packed weights, grouped scales, packed zero points) then asserts and reshapes. This supports FP8 per-channel, FP8 block quantization, W4A16 symmetric, and W4A16 asymmetric (with packed zero_point on dim0). Companion to #2426 (FX tracing fix) and compressed-tensors #609 (3D pack/unpack). Together they resolve #2338. TEST PLAN: 4 unit tests covering all quantization configurations: - int4 symmetric (packed weights, per-channel scale) - int4 asymmetric (packed weights + packed zero_point on dim0) - fp8 block (grouped scale) - fp8 per-channel (no packing) All passing. Signed-off-by: Gilles Turpin <turpingilles15@gmail.com> Co-authored-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
1 parent 4c52213 commit d6eb2be

File tree

2 files changed

+159
-14
lines changed

2 files changed

+159
-14
lines changed

src/llmcompressor/modeling/granite4.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,65 @@ def from_3d_expert(cls, original: GraniteMoeHybridParallelExperts):
3535

3636
def to_3d_expert(self) -> None:
    """Convert weights and quantization parameters from 2D to 3D shape.

    Supports FP8 per-channel, FP8 block quantization, W4A16 symmetric
    (packed weights), and W4A16 asymmetric (packed weights plus a
    zero point packed along dim 0). All shapes are derived from the
    stored tensors rather than hardcoded, so packed/grouped layouts
    reshape correctly.
    """
    # Derive the pack factor from the stored weight: int4 packing shrinks
    # the input dim by pack_factor; unpacked (FP8) weights give factor 1.
    packed_input_size = self.weight.shape[1]
    pack_factor = self.input_size // packed_input_size

    assert hasattr(self, "weight_scale"), "weight_scale not found"
    # Scale geometry per expert; grouped_input > 1 means block quantization.
    grouped_output = self.weight_scale.shape[0] // self.num_experts
    grouped_input = self.weight_scale.shape[1]

    # Expected 2D layouts and their 3D (num_experts, rows, cols) targets.
    expected_packed_weight_shape = torch.Size(
        (self.num_experts * self.output_size, packed_input_size)
    )
    final_packed_weight_shape = torch.Size(
        (self.num_experts, self.output_size, packed_input_size)
    )
    expected_packed_weight_scale_shape = torch.Size(
        (self.num_experts * grouped_output, grouped_input)
    )
    final_packed_weight_scale_shape = torch.Size(
        (self.num_experts, grouped_output, grouped_input)
    )

    # Validate before reshaping so a layout mismatch fails loudly.
    assert self.weight.shape == expected_packed_weight_shape, (
        f"weight shape {self.weight.shape} != "
        f"expected {expected_packed_weight_shape}"
    )
    assert self.weight_scale.shape == expected_packed_weight_scale_shape, (
        f"weight_scale shape {self.weight_scale.shape} != "
        f"expected {expected_packed_weight_scale_shape}"
    )

    # Reshape weight and scale to 3D.
    self.weight = torch.nn.Parameter(
        self.weight.view(final_packed_weight_shape).clone(),
        requires_grad=False,
    )
    self.weight_scale = torch.nn.Parameter(
        self.weight_scale.view(final_packed_weight_scale_shape).clone(),
        requires_grad=False,
    )

    if hasattr(self, "weight_zero_point"):
        # Asymmetric quantization: the zero point is packed along dim 0,
        # so its row count is divided by pack_factor as well.
        expected_packed_zp_shape = torch.Size(
            (self.num_experts * grouped_output // pack_factor, grouped_input)
        )
        final_packed_zp_shape = torch.Size(
            (self.num_experts, grouped_output // pack_factor, grouped_input)
        )
        assert self.weight_zero_point.shape == expected_packed_zp_shape, (
            f"weight_zero_point shape {self.weight_zero_point.shape} != "
            f"expected {expected_packed_zp_shape}"
        )
        self.weight_zero_point = torch.nn.Parameter(
            self.weight_zero_point.view(final_packed_zp_shape).clone(),
            requires_grad=False,
        )

    self.is_2d = False

6599
def forward(self, inputs, expert_size):
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from unittest.mock import MagicMock
2+
3+
import torch
4+
5+
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
6+
7+
8+
def _make_layer(
    num_experts, output_size, input_size, weight_shape, scale_shape, zp_shape=None
):
    """Build a mock expert layer with the given shapes for to_3d_expert tests.

    The mock carries only the attributes ``to_3d_expert`` reads. When
    ``zp_shape`` is None the ``weight_zero_point`` attribute is removed so
    that ``hasattr`` reports False for it.
    """
    layer = MagicMock(spec=GraniteMoeHybridParallelExpertsLinear)
    layer.num_experts = num_experts
    layer.output_size = output_size
    layer.input_size = input_size
    layer.is_2d = True
    layer.weight = torch.nn.Parameter(torch.randn(weight_shape), requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(
        torch.randn(scale_shape), requires_grad=False
    )
    if zp_shape is None:
        # Ensure hasattr(layer, "weight_zero_point") returns False.
        del layer.weight_zero_point
    else:
        layer.weight_zero_point = torch.nn.Parameter(
            torch.randn(zp_shape), requires_grad=False
        )
    return layer
29+
30+
31+
def test_to_3d_expert_int4_symmetric():
    """W4A16 symmetric: packed weight, per-channel scale, no zero_point."""
    num_experts, output_size, input_size = 4, 64, 128
    pack_factor = 8  # 4-bit packing
    packed_input = input_size // pack_factor
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, packed_input),
        scale_shape=(num_experts * output_size, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (num_experts, output_size, packed_input)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)
49+
50+
51+
def test_to_3d_expert_int4_asymmetric():
    """W4A16 asymmetric: packed weight + packed zero_point on dim0."""
    num_experts, output_size, input_size = 4, 64, 128
    pack_factor = 8  # 4-bit packing
    packed_input = input_size // pack_factor
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, packed_input),
        scale_shape=(num_experts * output_size, 1),
        zp_shape=(num_experts * output_size // pack_factor, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (num_experts, output_size, packed_input)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)
    assert layer.weight_zero_point.shape == (
        num_experts,
        output_size // pack_factor,
        1,
    )
75+
76+
77+
def test_to_3d_expert_fp8_block():
    """FP8 block quantization: grouped scale, no packing."""
    num_experts, output_size, input_size = 4, 64, 128
    group_size = 32
    num_row_groups = output_size  # per-row
    num_col_groups = input_size // group_size
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size),
        scale_shape=(num_experts * num_row_groups, num_col_groups),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (num_experts, output_size, input_size)
    assert layer.weight_scale.shape == (
        num_experts,
        num_row_groups,
        num_col_groups,
    )
97+
98+
99+
def test_to_3d_expert_fp8_per_channel():
    """FP8 per-channel: no packing, scale per row."""
    num_experts, output_size, input_size = 4, 64, 128
    full_rows = num_experts * output_size
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(full_rows, input_size),
        scale_shape=(full_rows, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (num_experts, output_size, input_size)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)

0 commit comments

Comments
 (0)