Skip to content

Commit 9b8dda6

Browse files
authored
Merge branch 'main' into ddp-v3-fix-imports
2 parents 2140a21 + d6eb2be commit 9b8dda6

File tree

2 files changed

+159
-14
lines changed

2 files changed

+159
-14
lines changed

src/llmcompressor/modeling/granite4.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,65 @@ def from_3d_expert(cls, original: GraniteMoeHybridParallelExperts):
3535

3636
def to_3d_expert(self) -> None:
    """Convert weights and quantization parameters from 2D to 3D shape.

    The 2D layout stacks all experts along dim 0:
        weight:            (num_experts * output_size, packed_input_size)
        weight_scale:      (num_experts * grouped_output, grouped_input)
        weight_zero_point: (num_experts * grouped_output // pack_factor,
                            grouped_input)              [optional]
    and this method reshapes each to a per-expert 3D layout with
    ``num_experts`` as the leading dimension.

    Shapes are derived rather than hard-coded so this handles:
      * packed low-bit weights (``packed_input_size < input_size``,
        ``pack_factor > 1``) as well as unpacked weights (pack_factor == 1);
      * per-channel scales (grouped_input == 1) and grouped/block scales
        (grouped_input > 1).

    Raises:
        AssertionError: if ``weight_scale`` is missing or any tensor's
            shape disagrees with the derived expectations.
    """
    # Calculate all shapes up front.
    # pack_factor == 1 when the weight is not packed along the input dim.
    packed_input_size = self.weight.shape[1]
    pack_factor = self.input_size // packed_input_size

    assert hasattr(self, "weight_scale"), "weight_scale not found"
    # Derive the per-expert scale grid from the flattened scale tensor.
    grouped_output = self.weight_scale.shape[0] // self.num_experts
    grouped_input = self.weight_scale.shape[1]

    expected_packed_weight_shape = torch.Size(
        (self.num_experts * self.output_size, packed_input_size)
    )
    final_packed_weight_shape = torch.Size(
        (self.num_experts, self.output_size, packed_input_size)
    )

    expected_packed_weight_scale_shape = torch.Size(
        (self.num_experts * grouped_output, grouped_input)
    )
    final_packed_weight_scale_shape = torch.Size(
        (self.num_experts, grouped_output, grouped_input)
    )

    # Assert shapes match expectations before reshaping.
    assert self.weight.shape == expected_packed_weight_shape, (
        f"weight shape {self.weight.shape} != "
        f"expected {expected_packed_weight_shape}"
    )

    assert self.weight_scale.shape == expected_packed_weight_scale_shape, (
        f"weight_scale shape {self.weight_scale.shape} != "
        f"expected {expected_packed_weight_scale_shape}"
    )

    # Reshape to 3D. `.clone()` materializes a contiguous copy so the new
    # Parameter does not alias the old 2D storage.
    self.weight = torch.nn.Parameter(
        self.weight.view(final_packed_weight_shape).clone(),
        requires_grad=False,
    )
    self.weight_scale = torch.nn.Parameter(
        self.weight_scale.view(final_packed_weight_scale_shape).clone(),
        requires_grad=False,
    )

    # Zero points are only present for asymmetric schemes; they are packed
    # along dim 0 by the same factor as the weight's input dim.
    if hasattr(self, "weight_zero_point"):
        expected_packed_zp_shape = torch.Size(
            (self.num_experts * grouped_output // pack_factor, grouped_input)
        )
        final_packed_zp_shape = torch.Size(
            (self.num_experts, grouped_output // pack_factor, grouped_input)
        )
        assert self.weight_zero_point.shape == expected_packed_zp_shape, (
            f"weight_zero_point shape {self.weight_zero_point.shape} != "
            f"expected {expected_packed_zp_shape}"
        )
        self.weight_zero_point = torch.nn.Parameter(
            self.weight_zero_point.view(final_packed_zp_shape).clone(),
            requires_grad=False,
        )

    self.is_2d = False

6599
def forward(self, inputs, expert_size):
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from unittest.mock import MagicMock
2+
3+
import torch
4+
5+
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
6+
7+
8+
def _make_layer(
    num_experts, output_size, input_size, weight_shape, scale_shape, zp_shape=None
):
    """Create a mock layer with the given shapes to test to_3d_expert.

    Args:
        num_experts: number of experts stacked along dim 0 of the 2D layout.
        output_size: per-expert output dimension.
        input_size: per-expert (unpacked) input dimension.
        weight_shape: 2D shape for the mock ``weight`` tensor.
        scale_shape: 2D shape for the mock ``weight_scale`` tensor.
        zp_shape: optional 2D shape for ``weight_zero_point``; when ``None``
            the attribute is removed so ``hasattr`` returns ``False``.

    Returns:
        A ``MagicMock`` specced to ``GraniteMoeHybridParallelExpertsLinear``
        carrying real tensors for the attributes ``to_3d_expert`` reads.
    """
    layer = MagicMock(spec=GraniteMoeHybridParallelExpertsLinear)
    layer.num_experts = num_experts
    layer.output_size = output_size
    layer.input_size = input_size
    layer.weight = torch.nn.Parameter(torch.randn(weight_shape), requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(
        torch.randn(scale_shape), requires_grad=False
    )
    layer.is_2d = True
    if zp_shape is not None:
        layer.weight_zero_point = torch.nn.Parameter(
            torch.randn(zp_shape), requires_grad=False
        )
    else:
        # hasattr should return False for weight_zero_point
        del layer.weight_zero_point
    return layer
29+
30+
31+
def test_to_3d_expert_int4_symmetric():
    """W4A16 symmetric: packed weight, per-channel scale, no zero_point.

    The weight's input dim is packed by ``pack_factor``; the scale stays
    per-output-channel, so its trailing dim is 1.
    """
    num_experts, output_size, input_size = 4, 64, 128
    pack_factor = 8  # 4-bit packing
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size // pack_factor),
        scale_shape=(num_experts * output_size, 1),
    )
    # Call the unbound method on the mock so no real module is constructed.
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (
        num_experts,
        output_size,
        input_size // pack_factor,
    )
    assert layer.weight_scale.shape == (num_experts, output_size, 1)
49+
50+
51+
def test_to_3d_expert_int4_asymmetric():
    """W4A16 asymmetric: packed weight + packed zero_point on dim0.

    The zero point is packed along the output dim by the same factor as the
    weight's input dim, so its dim0 is ``num_experts * output_size // pack_factor``.
    """
    num_experts, output_size, input_size = 4, 64, 128
    pack_factor = 8
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size // pack_factor),
        scale_shape=(num_experts * output_size, 1),
        zp_shape=(num_experts * output_size // pack_factor, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (
        num_experts,
        output_size,
        input_size // pack_factor,
    )
    assert layer.weight_scale.shape == (num_experts, output_size, 1)
    assert layer.weight_zero_point.shape == (
        num_experts,
        output_size // pack_factor,
        1,
    )
75+
76+
77+
def test_to_3d_expert_fp8_block():
    """FP8 block quantization: grouped scale, no packing.

    The weight is unpacked (pack_factor == 1) and the scale grid has one
    entry per (row, column-group) block.
    """
    num_experts, output_size, input_size = 4, 64, 128
    group_size = 32
    num_row_groups = output_size  # per-row
    num_col_groups = input_size // group_size
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size),
        scale_shape=(num_experts * num_row_groups, num_col_groups),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (num_experts, output_size, input_size)
    assert layer.weight_scale.shape == (
        num_experts,
        num_row_groups,
        num_col_groups,
    )
97+
98+
99+
def test_to_3d_expert_fp8_per_channel():
    """FP8 per-channel: no packing, scale per row.

    Both the weight and the per-channel scale reshape with a plain
    ``num_experts`` leading dimension.
    """
    num_experts, output_size, input_size = 4, 64, 128
    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size),
        scale_shape=(num_experts * output_size, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)
    assert layer.weight.shape == (num_experts, output_size, input_size)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)

0 commit comments

Comments
 (0)