Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions src/llmcompressor/modeling/granite4.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,65 @@ def from_3d_expert(cls, original: GraniteMoeHybridParallelExperts):

def to_3d_expert(self) -> None:
"""Convert weights and quantization parameters from 2D to 3D shape."""
dim0_mul = self.num_experts * self.output_size
assert (
self.weight.shape == torch.Size((dim0_mul, self.input_size))
and hasattr(self, "weight_scale")
and self.weight_scale.shape == torch.Size((dim0_mul, 1))
), "Shape mismatch, please check."
# Calculate all shapes up front
packed_input_size = self.weight.shape[1]
pack_factor = self.input_size // packed_input_size

assert hasattr(self, "weight_scale"), "weight_scale not found"
grouped_output = self.weight_scale.shape[0] // self.num_experts
grouped_input = self.weight_scale.shape[1]

expected_packed_weight_shape = torch.Size(
(self.num_experts * self.output_size, packed_input_size)
)
final_packed_weight_shape = torch.Size(
(self.num_experts, self.output_size, packed_input_size)
)

expected_packed_weight_scale_shape = torch.Size(
(self.num_experts * grouped_output, grouped_input)
)
final_packed_weight_scale_shape = torch.Size(
(self.num_experts, grouped_output, grouped_input)
)

# Assert shapes match expectations
assert self.weight.shape == expected_packed_weight_shape, (
f"weight shape {self.weight.shape} != "
f"expected {expected_packed_weight_shape}"
)

assert self.weight_scale.shape == expected_packed_weight_scale_shape, (
f"weight_scale shape {self.weight_scale.shape} != "
f"expected {expected_packed_weight_scale_shape}"
)

# Reshape to 3D
self.weight = torch.nn.Parameter(
self.weight.view(
self.num_experts, self.output_size, self.input_size
).clone(),
self.weight.view(final_packed_weight_shape).clone(),
requires_grad=False,
)
self.weight_scale = torch.nn.Parameter(
self.weight_scale.view(self.num_experts, self.output_size, 1).clone(),
self.weight_scale.view(final_packed_weight_scale_shape).clone(),
requires_grad=False,
)

if hasattr(self, "weight_zero_point"):
assert self.weight_zero_point.shape == torch.Size((dim0_mul, 1))
expected_packed_zp_shape = torch.Size(
(self.num_experts * grouped_output // pack_factor, grouped_input)
)
final_packed_zp_shape = torch.Size(
(self.num_experts, grouped_output // pack_factor, grouped_input)
)
assert self.weight_zero_point.shape == expected_packed_zp_shape, (
f"weight_zero_point shape {self.weight_zero_point.shape} != "
f"expected {expected_packed_zp_shape}"
)
self.weight_zero_point = torch.nn.Parameter(
self.weight_zero_point.view(
self.num_experts, self.output_size, 1
).clone(),
self.weight_zero_point.view(final_packed_zp_shape).clone(),
requires_grad=False,
)

self.is_2d = False

def forward(self, inputs, expert_size):
Expand Down
111 changes: 111 additions & 0 deletions tests/llmcompressor/modeling/test_granite4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from unittest.mock import MagicMock

import torch

from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear


def _make_layer(
    num_experts, output_size, input_size, weight_shape, scale_shape, zp_shape=None
):
    """Build a mocked GraniteMoeHybridParallelExpertsLinear for to_3d_expert tests.

    The mock carries only the attributes that ``to_3d_expert`` reads: the
    expert geometry, a 2D ``weight``/``weight_scale`` pair with the given
    shapes, and optionally a ``weight_zero_point`` when ``zp_shape`` is given.
    """
    mock = MagicMock(spec=GraniteMoeHybridParallelExpertsLinear)
    mock.num_experts = num_experts
    mock.output_size = output_size
    mock.input_size = input_size
    mock.is_2d = True
    mock.weight = torch.nn.Parameter(torch.randn(weight_shape), requires_grad=False)
    mock.weight_scale = torch.nn.Parameter(
        torch.randn(scale_shape), requires_grad=False
    )
    if zp_shape is None:
        # Remove the auto-created mock attribute so that
        # hasattr(mock, "weight_zero_point") evaluates to False.
        del mock.weight_zero_point
    else:
        mock.weight_zero_point = torch.nn.Parameter(
            torch.randn(zp_shape), requires_grad=False
        )
    return mock


def test_to_3d_expert_int4_symmetric():
    """W4A16 symmetric: packed weight, per-channel scale, no zero_point."""
    num_experts, output_size, input_size = 4, 64, 128
    pack_factor = 8  # 4-bit values packed along the input dimension
    packed_input = input_size // pack_factor

    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, packed_input),
        scale_shape=(num_experts * output_size, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)

    # Both tensors must be reshaped to a leading per-expert dimension.
    assert layer.weight.shape == (num_experts, output_size, packed_input)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)


def test_to_3d_expert_int4_asymmetric():
    """W4A16 asymmetric: packed weight + packed zero_point on dim0."""
    num_experts, output_size, input_size = 4, 64, 128
    pack_factor = 8  # 4-bit values packed along the input dimension
    packed_input = input_size // pack_factor
    packed_output = output_size // pack_factor

    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, packed_input),
        scale_shape=(num_experts * output_size, 1),
        zp_shape=(num_experts * packed_output, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)

    # weight, scale, and zero_point all gain a leading per-expert dimension;
    # the zero_point keeps its dim0 packing inside each expert.
    assert layer.weight.shape == (num_experts, output_size, packed_input)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)
    assert layer.weight_zero_point.shape == (num_experts, packed_output, 1)


def test_to_3d_expert_fp8_block():
    """FP8 block quantization: grouped scale, no packing."""
    num_experts, output_size, input_size = 4, 64, 128
    group_size = 32
    row_groups = output_size  # one group per output row
    col_groups = input_size // group_size

    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size),
        scale_shape=(num_experts * row_groups, col_groups),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)

    # Unpacked weight and grouped scale both split out the expert dimension.
    assert layer.weight.shape == (num_experts, output_size, input_size)
    assert layer.weight_scale.shape == (num_experts, row_groups, col_groups)


def test_to_3d_expert_fp8_per_channel():
    """FP8 per-channel: no packing, scale per row."""
    num_experts, output_size, input_size = 4, 64, 128

    layer = _make_layer(
        num_experts,
        output_size,
        input_size,
        weight_shape=(num_experts * output_size, input_size),
        scale_shape=(num_experts * output_size, 1),
    )
    GraniteMoeHybridParallelExpertsLinear.to_3d_expert(layer)

    # Full-precision layout: only the expert dimension is split out.
    assert layer.weight.shape == (num_experts, output_size, input_size)
    assert layer.weight_scale.shape == (num_experts, output_size, 1)