
Commit 337e067

slightly account for offloading
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 4096ffd commit 337e067

2 files changed: 59 additions, 12 deletions

src/llmcompressor/modeling/gpt_oss.py: 39 additions, 10 deletions
```diff
@@ -7,6 +7,8 @@
 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from llmcompressor.utils.dev import skip_weights_initialize
 
+from compressed_tensors import update_offload_parameter
+
 
 class GptOssExpert(torch.nn.Module):
     def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
@@ -56,18 +58,42 @@ def __init__(self, experts: GptOssExperts):
 
     def load_weights(self, experts: GptOssExperts):
         for expert_index, expert in enumerate(self.experts):
-            expert.gate_proj.weight.data = experts.gate_up_proj[expert_index, ..., ::2].data.T
-            expert.gate_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., ::2].data
-
-            expert.up_proj.weight.data = experts.gate_up_proj[expert_index, ..., 1::2].data.T
-            expert.up_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., 1::2].data
+            update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
+            update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
 
-            expert.down_proj.weight.data = experts.down_proj[expert_index].T
-            expert.down_proj.bias.data = experts.down_proj_bias[expert_index]
+            update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
+            update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
 
+            update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
+            update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
 
     def to_original(self) -> GptOssExperts:
-        pass
+        with skip_weights_initialize():
+            fake_config = GptOssConfig(
+                intermediate_size=self.intermediate_size,
+                num_local_experts=self.num_experts,
+                hidden_size=self.hidden_size,
+            )
+            experts = GptOssExperts(fake_config)
+
+        for expert_index, expert in enumerate(self.experts):
+            experts.gate_up_proj.data[expert_index, ..., ::2] = expert.gate_proj.weight.data.T
+            experts.gate_up_proj_bias.data[expert_index, ..., ::2] = expert.gate_proj.bias.data
+
+            experts.gate_up_proj.data[expert_index, ..., 1::2] = expert.up_proj.weight.data.T
+            experts.gate_up_proj_bias.data[expert_index, ..., 1::2] = expert.up_proj.bias.data
+
+            experts.down_proj.data[expert_index] = expert.down_proj.weight.data.T
+            experts.down_proj_bias.data[expert_index] = expert.down_proj.bias.data
+
+        # update offloaded state dict
+        update_offload_parameter(experts, "gate_up_proj", experts.gate_up_proj)
+        update_offload_parameter(experts, "gate_up_proj_bias", experts.gate_up_proj_bias)
+        update_offload_parameter(experts, "down_proj", experts.down_proj)
+        update_offload_parameter(experts, "down_proj_bias", experts.down_proj_bias)
+
+        return experts
 
 
     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
@@ -113,5 +139,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
 linear = GptOssExpertsLinear(original)
 output = linear(input, routing_weights=routing_weights)
 
-breakpoint()
-assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+
+restored = linear.to_original()
+restored_output = restored(input, routing_weights=routing_weights)
+assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)
```
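The switch from direct `.data` assignment to `update_offload_parameter` in `load_weights` is what "account for offloading" refers to: when a module is dispatched with accelerate-style CPU offload, its on-device parameter is re-materialized from an offloaded weights map before each forward, so a write that touches only the live tensor is silently discarded on the next onload. A minimal sketch of the idea, not the `compressed_tensors` implementation; the `_hf_hook` and `weights_map` attributes assume accelerate-style offloading with a writable map:

```python
import torch

def update_offload_parameter_sketch(module: torch.nn.Module, name: str, data: torch.Tensor) -> None:
    """Sketch: keep the live parameter and its offloaded copy in sync."""
    param = getattr(module, name)
    param.data.copy_(data.to(device=param.device, dtype=param.dtype))

    # accelerate keeps offloaded weights on an AlignDevicesHook; if the module
    # has one, refresh that copy too, otherwise the next pre-forward onload
    # would restore the stale value
    hook = getattr(module, "_hf_hook", None)
    if hook is not None and getattr(hook, "weights_map", None) is not None:
        hook.weights_map[name] = data.detach().cpu()
```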

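The `::2` / `1::2` slicing in `load_weights` and `to_original` encodes the fused layout: `GptOssExperts.gate_up_proj` stores the gate and up projections interleaved along its last axis (even columns gate, odd columns up), transposed relative to `nn.Linear`'s `(out_features, in_features)` convention. A self-contained check of that round trip with illustrative shapes; it also shows why `to_original` indexes into `.data` before assigning, since assigning to the `.data` attribute of a sliced view only rebinds the temporary view:

```python
import torch

hidden_size, expert_dim = 8, 4
fused = torch.randn(hidden_size, 2 * expert_dim)  # one expert's slice of gate_up_proj

# de-interleave into nn.Linear weight layout (out_features, in_features)
gate_w = fused[..., ::2].T   # (expert_dim, hidden_size)
up_w = fused[..., 1::2].T

# re-interleave and confirm the round trip is exact
refused = torch.empty_like(fused)
refused[..., ::2] = gate_w.T
refused[..., 1::2] = up_w.T
assert torch.equal(refused, fused)

# `.data` pitfall: assigning to `.data` of a sliced view rebinds the view and
# leaves the parent parameter untouched; indexing into `.data` writes through
p = torch.nn.Parameter(torch.zeros(2, 2), requires_grad=False)
p[0].data = torch.ones(2)
assert torch.equal(p.detach(), torch.zeros(2, 2))  # unchanged
p.data[0] = torch.ones(2)
assert torch.equal(p.detach()[0], torch.ones(2))   # updated
```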
src/llmcompressor/modeling/prepare.py: 20 additions, 2 deletions
```diff
@@ -1,9 +1,11 @@
+import contextlib
 from compressed_tensors.utils import replace_module
 from transformers import PreTrainedModel
 
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
+from llmcompressor.modeling.gpt_oss import GptOssExpertsLinear
 from llmcompressor.utils.helpers import patch_attr
 
 __all__ = ["replace_modules_for_calibration"]
@@ -42,13 +44,29 @@ def update_qwen3_moe(model, stack):
     )
 
 
-def update_gpt_oss_moe(model, stack):
-
+def update_gpt_oss_moe(model: PreTrainedModel, stack):
+    @contextlib.contextmanager
+    def replace_context(model, name, module):
+        linear = GptOssExpertsLinear(module)
+        replace_module(model, name, linear)
+        del module
+
+        yield
+
+        restored = linear.to_original()
+        replace_module(model, name, restored)
+
+    # TODO: need to think about duplicates
+    for name, module in model.named_modules():
+        cls_name = module.__class__.__name__
+        if cls_name == "GptOssExperts":
+            stack.enter_context(replace_context(model, name, module))
 
 
 
 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
+    "GptOssForCausalLM": update_gpt_oss_moe,
 }
```
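For context on how these hooks are meant to be driven, a sketch assuming a caller that owns a `contextlib.ExitStack`, as the `stack` parameter implies; `run_calibration` and `calibrate_in_moe_context` are hypothetical stand-ins for the caller's calibration loop:

```python
import contextlib

from transformers import PreTrainedModel

from llmcompressor.modeling.prepare import moe_context  # registry added by this commit

def calibrate_in_moe_context(model: PreTrainedModel, run_calibration) -> None:
    update_fn = moe_context.get(model.__class__.__name__)
    with contextlib.ExitStack() as stack:
        if update_fn is not None:
            # e.g. update_gpt_oss_moe: swaps each GptOssExperts for a
            # GptOssExpertsLinear and registers the restore step on the stack
            update_fn(model, stack)
        run_calibration(model)
    # unwinding the stack resumes replace_context past its yield: to_original()
    # rebuilds the fused module and replace_module() swaps it back in
```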
