Skip to content

Commit 6bb7905

Browse files
isharif168, dsikka, and shanjiaz
authored
[CPU] Linearize gpt_oss model and add example to quantize it to w4a8 (vllm-project#2113)
SUMMARY: Add code to convert the gpt_oss model and then quantize model to w8a8 TEST PLAN: Tested with gpt_oss model and run using vllm Sample output === Prompt 0 === <|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI. Knowledge cutoff: 2024-06 Current date: 2025-12-04 Reasoning: medium Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions You are a helpful assistant. <|end|><|start|>user<|message|>Give 3 reasons to use AI.<|end|><|start|>assistant --- Candidate 0 --- analysisWe need to answer with three reasons to use AI. Should be short. Provide 3 reasons. Let's do it.assistantcommentaryHere are three reasons to use AI: Personalized learning – AI can adapt to your learning style and pace, providing customized lessons that suit your learning style and progress. Skill development – AI can help you develop skills that are in demand, such as coding, data analysis, and data science. Supportive design We should provide a short answer.assistantfinalHere are three reasons to use AI: Personalized learning – AI adapts to your learning style and pace, giving you personalized lessons that fit your needs. Skill development – AI helps you learn and practice skills that are in demand, like coding, data analysis, and data‑science skills. Real‑world impact – AI helps you practice real-world skills in real life, such as learning new skills and applying them in real life. finish_reason: stop num_tokens: 226 --------- Signed-off-by: Sharif Inamdar <sharif.inamdar@arm.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
1 parent 6264c59 commit 6bb7905

File tree

2 files changed

+338
-0
lines changed

2 files changed

+338
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch
2+
from compressed_tensors.quantization import QuantizationScheme
3+
from compressed_tensors.quantization.quant_args import (
4+
QuantizationArgs,
5+
QuantizationStrategy,
6+
QuantizationType,
7+
)
8+
from transformers import AutoModelForCausalLM, AutoTokenizer
9+
10+
from llmcompressor import oneshot
11+
from llmcompressor.modeling.gpt_oss import convert_model_for_quantization_gptoss
12+
from llmcompressor.modifiers.quantization import QuantizationModifier
13+
14+
15+
def _build_w4a8_recipe() -> QuantizationModifier:
    """Build the W4A8 recipe: int4 channelwise symmetric static weights,
    int8 per-token dynamic asymmetric input activations, lm_head excluded."""
    # Weights: 4-bit, channelwise, symmetric, static
    weights_args = QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    )

    # Activations: 8-bit, per-token, asymmetric, dynamic.
    # observer=None because dynamic quantization needs no calibration observer.
    activations_args = QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=False,
        dynamic=True,
        observer=None,
    )

    # Apply to all Linear layers; lm_head is excluded via `ignore` below.
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=weights_args,
        input_activations=activations_args,
    )

    return QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    )


def main(model_id: str = "openai/gpt-oss-20b", output_dir=None):
    """Linearize a GPT-OSS model's fused MoE experts and quantize it to W4A8.

    Args:
        model_id: HF hub id or local path of the GPT-OSS checkpoint.
            Defaults to ``openai/gpt-oss-20b`` (the original hard-coded model).
        output_dir: Where to write the quantized model. Defaults to
            ``<model-name>-w4a8-channelwise`` next to the working directory.
    """
    MODEL_ID = model_id
    BASE_NAME = MODEL_ID.rstrip("/").split("/")[-1]
    OUTPUT_DIR = output_dir if output_dir is not None else f"{BASE_NAME}-w4a8-channelwise"

    print(f"[GPT-OSS] Loading model: {MODEL_ID}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- GPT-OSS MoE → linear experts conversion ----
    # Fused gate_up_proj/down_proj expert tensors are replaced with plain
    # nn.Linear modules so the quantizer can target them as "Linear".
    print("[GPT-OSS] Converting fused MoE experts to LinearExperts for quantization...")
    convert_model_for_quantization_gptoss(model)
    print("[GPT-OSS] Conversion completed.")

    # ---- Quantization config: W4A8 (int4 weights, int8 activations) ----
    recipe = _build_w4a8_recipe()

    print(f"[GPT-OSS] Starting oneshot quantization → {OUTPUT_DIR}")
    oneshot(
        model=model,
        recipe=recipe,
        tokenizer=tokenizer,
        output_dir=OUTPUT_DIR,
        trust_remote_code_model=True,
    )
    print(f"[GPT-OSS] Quantization finished. Quantized model written to: {OUTPUT_DIR}")
76+
77+
78+
# Run the end-to-end conversion + quantization when executed as a script.
if __name__ == "__main__":
    main()
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import List, Optional
5+
6+
import torch
7+
import torch.nn as nn
8+
9+
10+
class LinearExpert(nn.Module):
    """
    A single MoE expert with separate gate / up / down projections.

    Mirrors the GPT-OSS expert computation:
        gate = clamp(gate_proj(x))
        up   = clamp(up_proj(x))
        glu  = gate * sigmoid(alpha * gate)
        y    = down_proj((up + 1) * glu)
    """

    def __init__(
        self, hidden_size: int, intermediate_size: int, alpha: float, limit: float
    ):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

        # All three projections carry a bias, matching the GPT-OSS layout.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate is clamped from above only; up is clamped symmetrically.
        gate = self.gate_proj(x).clamp(max=self.limit)
        up = self.up_proj(x).clamp(min=-self.limit, max=self.limit)

        # GLU-style activation with the alpha-scaled sigmoid gate.
        glu = gate * torch.sigmoid(self.alpha * gate)
        return self.down_proj((up + 1) * glu)
42+
43+
44+
class LinearExperts(nn.Module):
    """
    Container of multiple LinearExpert modules, driven by
    router_indices / routing_weights.

    This is the "separate gate/up" layout.
    It is meant to replace the original GPT-OSS `experts` submodule.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
        alpha: float = 1.702,
        limit: float = 7.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # Kept under the name `expert_dim` so find_experts() can re-discover it.
        self.expert_dim = intermediate_size
        self.num_experts = num_experts
        self.alpha = alpha
        self.limit = limit

        # One independent LinearExpert per routed expert slot.
        self.experts = nn.ModuleList(
            [
                LinearExpert(hidden_size, intermediate_size, alpha, limit)
                for _ in range(num_experts)
            ]
        )

    @torch.no_grad()
    def copy_from_fused_weights(
        self,
        legacy_gate_up_W: torch.Tensor,  # [E, H, 2D]
        legacy_gate_up_b: torch.Tensor,  # [E, 2D]
        legacy_down_W: torch.Tensor,  # [E, D, H]
        legacy_down_b: torch.Tensor,  # [E, H]
    ) -> None:
        """
        De-interleave fused gate_up weights/bias and copy into separate gate/up experts.

        The fused tensor interleaves gate and up along the last axis
        (even columns = gate, odd columns = up), so we split by stride-2
        slicing and transpose into nn.Linear's [out, in] weight layout.
        """
        E, H, twoD = legacy_gate_up_W.shape
        assert E == self.num_experts
        D = twoD // 2
        assert D == self.expert_dim

        for i in range(E):
            Wi = legacy_gate_up_W[i]  # [H, 2D]
            bi = legacy_gate_up_b[i]  # [2D]

            # Even columns feed the gate projection, odd columns the up projection.
            Wg = Wi[:, 0::2].contiguous()  # [H, D]
            Wu = Wi[:, 1::2].contiguous()  # [H, D]
            bg = bi[0::2].contiguous()  # [D]
            bu = bi[1::2].contiguous()  # [D]

            expert = self.experts[i]
            # Transpose: fused layout is [in, out]; nn.Linear stores [out, in].
            expert.gate_proj.weight.copy_(Wg.t())
            expert.gate_proj.bias.copy_(bg)
            expert.up_proj.weight.copy_(Wu.t())
            expert.up_proj.bias.copy_(bu)

            expert.down_proj.weight.copy_(legacy_down_W[i].t())
            expert.down_proj.bias.copy_(legacy_down_b[i])

    def forward(
        self,
        hidden_states: torch.Tensor,  # [B, T, H]
        router_indices: Optional[
            torch.Tensor
        ] = None,  # [B, T, top_k] or [tokens, top_k]
        routing_weights: Optional[torch.Tensor] = None,  # [B, T, E] or [tokens, E]
    ) -> torch.Tensor:
        """
        Implements the MoE computation using the router outputs.

        This is compatible with the GPT-OSS MoE call pattern:
            experts(hidden_states, router_indices, routing_weights)
        """
        # Both router tensors are required despite the Optional defaults;
        # the defaults only exist to match the original call signature.
        assert (
            routing_weights is not None and router_indices is not None
        ), "router inputs required"

        # Normalize shapes to [tokens, H], [tokens, top_k], [tokens, E]
        if hidden_states.dim() == 3:
            B, T, H = hidden_states.shape
            x = hidden_states.reshape(-1, H)
        else:
            # Already flattened
            # NOTE(review): B is forced to 1 here, so a 2-D input comes back
            # as [1, tokens, H] from the final view — confirm callers expect that.
            B, _ = 1, hidden_states.shape[0]
            H = hidden_states.shape[-1]
            x = hidden_states

        if router_indices.dim() == 3:
            router_indices = router_indices.reshape(-1, router_indices.shape[-1])
        if routing_weights.dim() == 3:
            routing_weights = routing_weights.reshape(-1, routing_weights.shape[-1])

        # NOTE(review): assumes routing_weights' last dim already includes the
        # extra "no expert" column (hence "plus_dummy") — verify against router.
        num_experts_plus_dummy = routing_weights.shape[1]
        out = torch.zeros_like(x)

        # GPT-OSS router uses an extra "no expert" bucket at index E
        with torch.no_grad():
            # expert_mask: [num_classes, top_k, tokens] after the permute;
            # expert_hit lists only classes that received at least one token.
            expert_mask = torch.nn.functional.one_hot(
                router_indices, num_classes=num_experts_plus_dummy
            ).permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for idx in expert_hit:
            e = idx[0].item()
            if e == self.num_experts:
                # Skip "no expert" bucket
                continue

            # token_idx: positions routed to expert e (repeated if a token
            # selects the same expert in multiple top-k slots).
            _, token_idx = torch.where(expert_mask[e])
            xi = x[token_idx]

            expert = self.experts[e]
            yi = expert(xi)

            # Scale each expert output by its routing weight and scatter-add
            # back into the shared output buffer.
            w = routing_weights[token_idx, e, None]
            out.index_add_(0, token_idx, (yi * w).to(out.dtype))

        return out.view(B, -1, H)
168+
169+
170+
@dataclass
class ExpertMeta:
    """Metadata describing one fused GPT-OSS expert block to be replaced."""

    # Dotted attribute path to the experts module, e.g. "model.layers.0.mlp.experts".
    path: str
    # Model hidden size (H) read from the legacy experts module.
    hidden_size: int
    # Per-expert intermediate size (D).
    intermediate_size: int
    # Number of experts (E) in the MoE block.
    num_experts: int
    # Device/dtype of the legacy module's parameters, used to place the replacement.
    device: torch.device
    dtype: torch.dtype
178+
179+
180+
def get_module_by_path(root: nn.Module, dotpath: str) -> nn.Module:
    """Resolve a dotted attribute path against *root*; an empty path yields *root*."""
    if not dotpath:
        return root
    node: nn.Module = root
    for attr in dotpath.split("."):
        node = getattr(node, attr)
    return node
187+
188+
189+
def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> None:
    """Replace the module at *dotpath* under *root* with *new_module*."""
    # rpartition yields ("", "", name) for a top-level path, so the parent
    # lookup falls back to root itself in that case.
    parent_path, _, leaf_name = dotpath.rpartition(".")
    parent = get_module_by_path(root, parent_path)
    setattr(parent, leaf_name, new_module)
193+
194+
195+
def find_experts(model: nn.Module) -> List[ExpertMeta]:
    """
    Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts.

    Returns one ExpertMeta per layer, capturing the sizes plus the device
    and dtype of the legacy module's parameters so the replacement can be
    placed identically.
    """
    metas: List[ExpertMeta] = []
    for li, layer in enumerate(model.model.layers):
        experts = layer.mlp.experts

        # Probe a single parameter for placement info instead of walking the
        # parameters() generator twice; fall back to the defaults of a fresh
        # scalar tensor (CPU / default float dtype) for parameter-less modules.
        probe = next(experts.parameters(), torch.zeros(()))

        # Prefer the `expert_dim` attribute (used by LinearExperts); older
        # fused modules expose `intermediate_size` instead.
        intermediate = getattr(experts, "expert_dim", None)
        if intermediate is None:
            intermediate = experts.intermediate_size

        metas.append(
            ExpertMeta(
                path=f"model.layers.{li}.mlp.experts",
                hidden_size=experts.hidden_size,
                intermediate_size=intermediate,
                num_experts=experts.num_experts,
                device=probe.device,
                dtype=probe.dtype,
            )
        )
    return metas
219+
220+
221+
def convert_model_for_quantization_gptoss(model: nn.Module) -> None:
    """
    In-place conversion of a GPT-OSS model:

    - Finds all fused MoE expert blocks (with gate_up_proj/down_proj).
    - Replaces them with LinearExperts that expose plain nn.Linear
      parameters (gate_proj, up_proj, down_proj), which play nicely
      with LLM Compressor W4A8 quantization.
    """
    fused_attrs = (
        "gate_up_proj",
        "gate_up_proj_bias",
        "down_proj",
        "down_proj_bias",
    )

    for meta in find_experts(model):
        legacy = get_module_by_path(model, meta.path)

        # Sanity check: only convert modules carrying the fused layout.
        if any(not hasattr(legacy, attr) for attr in fused_attrs):
            continue

        replacement = LinearExperts(
            hidden_size=meta.hidden_size,
            intermediate_size=meta.intermediate_size,
            num_experts=meta.num_experts,
        ).to(device=meta.device, dtype=meta.dtype)

        replacement.copy_from_fused_weights(
            legacy_gate_up_W=legacy.gate_up_proj,
            legacy_gate_up_b=legacy.gate_up_proj_bias,
            legacy_down_W=legacy.down_proj,
            legacy_down_b=legacy.down_proj_bias,
        )

        set_module_by_path(model, meta.path, replacement)

0 commit comments

Comments
 (0)