From affa73f2ff58ea0b03425b75f9ef8e334c5e4dfc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 7 Aug 2025 04:30:00 +0000 Subject: [PATCH 1/3] working, jank Signed-off-by: Kyle Sayers --- .../quantization_w4a16/gpt_oss_example.py | 82 ++++++++ src/llmcompressor/modeling/gpt_oss.py | 181 ++++++++++++++++++ src/llmcompressor/modeling/prepare.py | 122 +++++++++++- .../pipelines/sequential/pipeline.py | 54 +++--- 4 files changed, 408 insertions(+), 31 deletions(-) create mode 100644 examples/quantization_w4a16/gpt_oss_example.py create mode 100644 src/llmcompressor/modeling/gpt_oss.py diff --git a/examples/quantization_w4a16/gpt_oss_example.py b/examples/quantization_w4a16/gpt_oss_example.py new file mode 100644 index 000000000..59bc0e198 --- /dev/null +++ b/examples/quantization_w4a16/gpt_oss_example.py @@ -0,0 +1,82 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modeling import replace_modules_for_calibration +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor import oneshot +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +model_id = "unsloth/gpt-oss-20b-BF16" +model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(model_id) +#replace_modules_for_calibration(model) # linearize experts so they can be targeted + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 1#512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) + +# Apply algorithms. +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + calibrate_moe_context=True, + pipeline="sequential", +) + +# # Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +sample = tokenizer("Hello my name is", return_tensors="pt") +sample = {key: value.to("cuda") for key, value in sample.items()} +output = model.generate(**sample, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. 
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py new file mode 100644 index 000000000..fd986404d --- /dev/null +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -0,0 +1,181 @@ +from typing import List + +import torch +import contextlib + +from transformers import GptOssForCausalLM +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts +from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig +from llmcompressor.utils.dev import skip_weights_initialize + +from compressed_tensors.utils import update_offload_parameter, align_module_device + + +class GptOssExpert(torch.nn.Module): + gate_proj: torch.nn.Linear + up_proj: torch.nn.Linear + down_proj: torch.nn.Linear + + def __init__(self, experts: GptOssExperts): + super().__init__() + + self.hidden_size = experts.hidden_size + self.expert_dim = experts.expert_dim + self.alpha = experts.alpha + self.limit = experts.limit + + assert experts.gate_up_proj.dtype == experts.gate_up_proj_bias.dtype + assert experts.down_proj.dtype == experts.down_proj_bias.dtype + + with skip_weights_initialize(): + self.gate_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype) + self.up_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype) + self.down_proj = torch.nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=experts.down_proj.dtype) + + def forward(self, hidden_states: torch.Tensor): + gate = self.gate_proj(hidden_states) + gate = gate.clamp(min=None, max=self.limit) + + up = self.up_proj(hidden_states) + up = up.clamp(min=-self.limit, max=self.limit) + + glu = gate * torch.sigmoid(gate * self.alpha) + return self.down_proj((up + 1) * glu) + + + +class GptOssExpertsLinear(torch.nn.Module): + experts: List[GptOssExpert] + + def __init__(self, experts: GptOssExperts): + super().__init__() + + self.intermediate_size = experts.intermediate_size + self.num_experts = experts.num_experts + self.hidden_size = experts.hidden_size + self.expert_dim = experts.expert_dim + + with skip_weights_initialize(): + self.experts = torch.nn.ModuleList([GptOssExpert(experts) for _ in range(self.num_experts)]) + + self.load_weights(experts) + + self.alpha = experts.alpha + self.limit = experts.limit + + def load_weights(self, experts: GptOssExperts): + with align_module_device(experts): + for expert_index, expert in enumerate(self.experts): + update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T) + update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2]) + + update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T) + update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2]) + + update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T) + update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index]) + + def to_original(self) -> GptOssExperts: + # TODO: this doesn't really handle offloading or correct device placement + with skip_weights_initialize(use_zeros=True): + fake_config = GptOssConfig( + intermediate_size=self.intermediate_size, + num_local_experts=self.num_experts, + hidden_size=self.hidden_size, + ) + experts = 
GptOssExperts(fake_config)
+        experts.gate_up_proj = torch.nn.Parameter(experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype), requires_grad=False)
+        experts.gate_up_proj_bias = torch.nn.Parameter(experts.gate_up_proj_bias.to(dtype=self.experts[0].gate_proj.weight.dtype), requires_grad=False)
+        experts.down_proj = torch.nn.Parameter(experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype), requires_grad=False)
+        experts.down_proj_bias = torch.nn.Parameter(experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype), requires_grad=False)
+
+        for expert_index, expert in enumerate(self.experts):
+            with align_module_device(expert.gate_proj, "cpu"), align_module_device(expert.up_proj, "cpu"), align_module_device(expert.down_proj, "cpu"):
+                experts.gate_up_proj[expert_index, ..., ::2].copy_(expert.gate_proj.weight.data.T)
+                experts.gate_up_proj_bias[expert_index, ..., ::2].copy_(expert.gate_proj.bias.data)
+
+                experts.gate_up_proj[expert_index, ..., 1::2].copy_(expert.up_proj.weight.data.T)
+                experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_(expert.up_proj.bias.data)
+
+                experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T)
+                experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data)
+
+        print("converted, for some reason slows down over time")
+        import time
+        print(time.time())
+
+        experts.eval()
+        return experts
+
+
+    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        """
+        When training it is more efficient to just loop over the experts and compute the output for each expert
+        as otherwise the memory would explode.
+
+        For inference we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+ + Args: + hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + Returns: + torch.Tensor + """ + original_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + for expert_index, expert in enumerate(self.experts): + next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1) + + next_states = next_states.reshape(original_shape) + return next_states + +def replace_gpt_oss(config: GptOssConfig, module: GptOssExpert): + return GptOssExpertsLinear(module) + + +def test_restore(): + config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5) + + original = GptOssExperts(config) + linear = GptOssExpertsLinear(original) + + restored = linear.to_original() + for param_name, param in original.named_parameters(recurse=False): + restored_param = getattr(restored, param_name) + assert param.shape == restored_param.shape + assert param.dtype == restored_param.dtype + + assert torch.all(getattr(restored, param_name) == param) + + +def test_correctness(): + batch_size, seq_len = 13, 12 + config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5) + + input = torch.rand((batch_size, seq_len, config.hidden_size)) + routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts)) + + with torch.no_grad(): + original = GptOssExperts(config) + for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]: + setattr(original, name, getattr(original, name).normal_()) + + original.eval() + assert original.training == False + true_output = original(input, routing_weights=routing_weights) + + linear = GptOssExpertsLinear(original) + output = linear(input, routing_weights=routing_weights) + + assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0) + + restored = linear.to_original() + restored_output = restored(input, routing_weights=routing_weights) + assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0) + + +if __name__ == "__main__": + test_restore() \ No newline at end of file diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index cb61f5fad..609880c89 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,10 +1,16 @@ -from compressed_tensors.utils import replace_module +import contextlib +import tqdm +from compressed_tensors.utils import replace_module, delete_offload_module, register_offload_module, get_offloaded_device +from compressed_tensors.utils.offload import offload_to_weights_map from transformers import PreTrainedModel from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE +from llmcompressor.modeling.gpt_oss import GptOssExpertsLinear, replace_gpt_oss from llmcompressor.utils.helpers import patch_attr +from accelerate.hooks import add_hook_to_module, remove_hook_from_module, AlignDevicesHook, named_module_tensors, set_module_tensor_to_device, PrefixedDataset +from accelerate.big_modeling import attach_align_device_hook_on_blocks __all__ = ["replace_modules_for_calibration"] @@ -12,11 +18,13 @@ replacements = { "DeepseekV3MoE": 
replace_deepseekv3, "Llama4TextMoe": replace_llama4, + "GptOssExperts": replace_gpt_oss, } def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: - for name, module in model.named_modules(): + modules = list(model.named_modules()) + for name, module in tqdm.tqdm(modules, desc="Converting modules"): cls_name = module.__class__.__name__ if cls_name in replacements: new_module = replacements[cls_name](config=model.config, module=module) @@ -28,8 +36,9 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: # ------------------- module replacements; during calibration -------------------- -def update_qwen3_moe(model, stack): - for module in model.modules(): +def update_qwen3_moe(model: PreTrainedModel, stack): + modules = list(model.modules()) + for module in tqdm.tqdm(modules, desc="Converting modules"): cls_name = module.__class__.__name__ if cls_name == "Qwen3MoeDecoderLayer": # Optionally update the model.config to pass in other arguments @@ -41,9 +50,36 @@ def update_qwen3_moe(model, stack): ) ) +def update_gpt_oss(model: PreTrainedModel, stack): + @contextlib.contextmanager + def replace(mod_name, module, name, original): + hook: AlignDevicesHook = original._hf_hook + + replacement = replace_gpt_oss(model.config, original) + replace_offload_module(module, name, hook, replacement) + del original + + yield + + restored = replacement.to_original() + delattr(module, name) + module.register_module(name, restored) + #replace_offload_module(module, name, hook, restored) + del replacement + + + modules = list(model.named_modules()) + for name, module in tqdm.tqdm(modules, desc="Converting modules"): + for child_name, child in list(module.named_children()): + if child.__class__.__name__ == "GptOssExperts": + stack.enter_context(replace(name, module, child_name, child)) + + + moe_context = { "Qwen3MoeForCausalLM": update_qwen3_moe, + "GptOssForCausalLM": update_gpt_oss, } @@ -53,3 +89,81 @@ def moe_calibration_context(model: PreTrainedModel, stack): cls_name = model.__class__.__name__ if cls_name in moe_context: moe_context.get(cls_name)(model, stack) + + + + +def replace_offload_module(base, name: str, hook: AlignDevicesHook, module): + delattr(base, name) + + assert hook.offload + assert hook.weights_map is not None + + # offload parameters to weights map + offload_device = "cpu" + for param_name, param in named_module_tensors( + module, include_buffers=hook.offload_buffers, recurse=True + ): + offloaded = param.to(offload_device) + if hook.tied_params_map is not None: + hook.tied_params_map[offloaded.data_ptr()] = {} # (1) + offload_to_weights_map(hook.weights_map, param_name, offloaded) + + # attach hooks and offload weights + attach_align_device_hook_on_blocks( + module, + hook.execution_device, + hook.offload, + hook.weights_map, + hook.offload_buffers, + "", + hook.skip_keys, + None, + hook.tied_params_map, + ) + + base.register_module(name, module) + + + # # offloading kwargs for submodule + # place_submodules = False + # offload_buffers = True + + # # copy device offloading arguments from parent + # current_device = next(base.parameters()).device # assume base has parameters + # offload_device = get_offloaded_device(base) + + # # offload parameters to weights map + # for param_name, param in named_module_tensors( + # module, include_buffers=offload_buffers, recurse=place_submodules + # ): + # offloaded = param.to(offload_device) + # if hook.tied_params_map is not None: + # hook.tied_params_map[offloaded.data_ptr()] = {} # (1) + # 
offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded) + + # # if the parent places submodules, offload here + # if hook.place_submodules: + # set_module_tensor_to_device(module, param_name, current_device) + + # if not hook.place_submodules: + # weights_map = PrefixedDataset( + # hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}." + # ) + + # submodule_hook = AlignDevicesHook( + # execution_device=hook.execution_device, + # offload=hook.offload, + # io_same_device=False, + # weights_map=weights_map, + # offload_buffers=offload_buffers, + # place_submodules=place_submodules, + # skip_keys=None, + # tied_params_map=hook.tied_params_map, + # ) + # add_hook_to_module(module, submodule_hook) + + # base.register_module(name, module) + # for c_name, child in list(module.named_children()): + # register_offload_module(module, c_name, child) + # replace_offload_module(module, None, c_name, child, child) \ No newline at end of file diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 901283252..bdc260003 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -80,33 +80,33 @@ def __call__( if dataset_args.calibrate_moe_context: moe_calibration_context(model, stack) - # prepare intermediates cache - activations = IntermediatesCache.from_dataloader(dataloader, model_device) - - for subgraph_index, subgraph in enumerate(subgraphs): - # prepare tqdm description texts - calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" - prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" - - # reduce memory movement by keeping modules onloaded - with disable_offloading(): - # do a preliminary pass to trigger modifier hooks - for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): - inputs = activations.fetch(batch_idx, subgraph.input_names) - subgraph.forward(model, **inputs) - - LifecycleCallbacks.sequential_epoch_end() - - # this pass does not trigger modifier hooks - # and is only used for capturing outputs of newly compressed modules - with HooksMixin.disable_hooks(): - for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc): - inputs = activations.fetch(batch_idx, subgraph.input_names) - output = subgraph.forward(model, **inputs) - - if subgraph_index < num_subgraphs - 1: - activations.update(batch_idx, output) - activations.delete(batch_idx, subgraph.consumed_names) + # # prepare intermediates cache + # activations = IntermediatesCache.from_dataloader(dataloader, model_device) + + # for subgraph_index, subgraph in enumerate(subgraphs): + # # prepare tqdm description texts + # calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" + # prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" + + # # reduce memory movement by keeping modules onloaded + # with disable_offloading(): + # # do a preliminary pass to trigger modifier hooks + # for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): + # inputs = activations.fetch(batch_idx, subgraph.input_names) + # subgraph.forward(model, **inputs) + + # LifecycleCallbacks.sequential_epoch_end() + + # # this pass does not trigger modifier hooks + # # and is only used for capturing outputs of newly compressed modules + # with HooksMixin.disable_hooks(): + # for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc): + # inputs = activations.fetch(batch_idx, subgraph.input_names) + # output = subgraph.forward(model, 
**inputs) + + # if subgraph_index < num_subgraphs - 1: + # activations.update(batch_idx, output) + # activations.delete(batch_idx, subgraph.consumed_names) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end() From 119b7c2b95e2c398b08b3c6364990d9ffe5e2133 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 7 Aug 2025 04:32:59 +0000 Subject: [PATCH 2/3] clean up some jank Signed-off-by: Kyle Sayers --- .../quantization_w4a16/gpt_oss_example.py | 6 +- src/llmcompressor/modeling/gpt_oss.py | 148 +++++++++++++----- src/llmcompressor/modeling/prepare.py | 63 +------- .../pipelines/sequential/pipeline.py | 54 +++---- 4 files changed, 147 insertions(+), 124 deletions(-) diff --git a/examples/quantization_w4a16/gpt_oss_example.py b/examples/quantization_w4a16/gpt_oss_example.py index 59bc0e198..142f0fbb6 100644 --- a/examples/quantization_w4a16/gpt_oss_example.py +++ b/examples/quantization_w4a16/gpt_oss_example.py @@ -1,16 +1,14 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from llmcompressor.modeling import replace_modules_for_calibration -from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.utils import dispatch_for_generation # Select model and load it. model_id = "unsloth/gpt-oss-20b-BF16" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) -#replace_modules_for_calibration(model) # linearize experts so they can be targeted # Select calibration dataset. DATASET_ID = "HuggingFaceH4/ultrachat_200k" @@ -18,7 +16,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 1#512 +NUM_CALIBRATION_SAMPLES = 1 # 512 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. 
diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index fd986404d..a8044a3b4 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -1,14 +1,12 @@ +# flake8: noqa from typing import List import torch -import contextlib - -from transformers import GptOssForCausalLM -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts +from compressed_tensors.utils import align_module_device, update_offload_parameter from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig -from llmcompressor.utils.dev import skip_weights_initialize +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts -from compressed_tensors.utils import update_offload_parameter, align_module_device +from llmcompressor.utils.dev import skip_weights_initialize class GptOssExpert(torch.nn.Module): @@ -28,9 +26,24 @@ def __init__(self, experts: GptOssExperts): assert experts.down_proj.dtype == experts.down_proj_bias.dtype with skip_weights_initialize(): - self.gate_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype) - self.up_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype) - self.down_proj = torch.nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=experts.down_proj.dtype) + self.gate_proj = torch.nn.Linear( + self.hidden_size, + self.expert_dim, + bias=True, + dtype=experts.gate_up_proj.dtype, + ) + self.up_proj = torch.nn.Linear( + self.hidden_size, + self.expert_dim, + bias=True, + dtype=experts.gate_up_proj.dtype, + ) + self.down_proj = torch.nn.Linear( + self.expert_dim, + self.hidden_size, + bias=True, + dtype=experts.down_proj.dtype, + ) def forward(self, hidden_states: torch.Tensor): gate = self.gate_proj(hidden_states) @@ -42,7 +55,6 @@ def forward(self, hidden_states: torch.Tensor): glu = gate * torch.sigmoid(gate * self.alpha) return self.down_proj((up + 1) * glu) - class GptOssExpertsLinear(torch.nn.Module): experts: List[GptOssExpert] @@ -56,7 +68,9 @@ def __init__(self, experts: GptOssExperts): self.expert_dim = experts.expert_dim with skip_weights_initialize(): - self.experts = torch.nn.ModuleList([GptOssExpert(experts) for _ in range(self.num_experts)]) + self.experts = torch.nn.ModuleList( + [GptOssExpert(experts) for _ in range(self.num_experts)] + ) self.load_weights(experts) @@ -66,14 +80,34 @@ def __init__(self, experts: GptOssExperts): def load_weights(self, experts: GptOssExperts): with align_module_device(experts): for expert_index, expert in enumerate(self.experts): - update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T) - update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2]) - - update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T) - update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2]) - - update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T) - update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index]) + update_offload_parameter( + expert.gate_proj, + "weight", + experts.gate_up_proj[expert_index, ..., ::2].T, + ) + update_offload_parameter( + expert.gate_proj, + "bias", + experts.gate_up_proj_bias[expert_index, ..., ::2], + ) + + update_offload_parameter( + expert.up_proj, + "weight", + 
experts.gate_up_proj[expert_index, ..., 1::2].T, + ) + update_offload_parameter( + expert.up_proj, + "bias", + experts.gate_up_proj_bias[expert_index, ..., 1::2], + ) + + update_offload_parameter( + expert.down_proj, "weight", experts.down_proj[expert_index].T + ) + update_offload_parameter( + expert.down_proj, "bias", experts.down_proj_bias[expert_index] + ) def to_original(self) -> GptOssExperts: # TODO: this doesn't really handle offloading or correct device placement @@ -84,31 +118,57 @@ def to_original(self) -> GptOssExperts: hidden_size=self.hidden_size, ) experts = GptOssExperts(fake_config) - experts.gate_up_proj = torch.nn.Parameter(experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype), requires_grad=False) - experts.gate_up_proj_bias = torch.nn.Parameter(experts.gate_up_proj_bias.to(dtype=self.experts[0].gate_proj.weight.dtype), requires_grad=False) - experts.down_proj = torch.nn.Parameter(experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype), requires_grad=False) - experts.down_proj_bias = torch.nn.Parameter(experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype), requires_grad=False) + experts.gate_up_proj = torch.nn.Parameter( + experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype), + requires_grad=False, + ) + experts.gate_up_proj_bias = torch.nn.Parameter( + experts.gate_up_proj_bias.to( + dtype=self.experts[0].gate_proj.weight.dtype + ), + requires_grad=False, + ) + experts.down_proj = torch.nn.Parameter( + experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype), + requires_grad=False, + ) + experts.down_proj_bias = torch.nn.Parameter( + experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype), + requires_grad=False, + ) for expert_index, expert in enumerate(self.experts): - with align_module_device(expert.gate_proj, "cpu"), align_module_device(expert.up_proj, "cpu"), align_module_device(expert.down_proj, "cpu"): - experts.gate_up_proj[expert_index, ..., ::2].copy_(expert.gate_proj.weight.data.T) - experts.gate_up_proj_bias[expert_index, ..., ::2].copy_(expert.gate_proj.bias.data) - - experts.gate_up_proj[expert_index, ..., 1::2].copy_(expert.up_proj.weight.data.T) - experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_(expert.up_proj.bias.data) + with align_module_device(expert.gate_proj, "cpu"), align_module_device( + expert.up_proj, "cpu" + ), align_module_device(expert.down_proj, "cpu"): + experts.gate_up_proj[expert_index, ..., ::2].copy_( + expert.gate_proj.weight.data.T + ) + experts.gate_up_proj_bias[expert_index, ..., ::2].copy_( + expert.gate_proj.bias.data + ) + + experts.gate_up_proj[expert_index, ..., 1::2].copy_( + expert.up_proj.weight.data.T + ) + experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_( + expert.up_proj.bias.data + ) experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T) experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data) print("converted, for some reason slows down over time") import time + print(time.time()) experts.eval() return experts - - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: """ When training is is more efficient to just loop over the experts and compute the output for each expert as otherwise the memory would explode. 
@@ -123,15 +183,22 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig torch.Tensor """ original_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + hidden_states = hidden_states.reshape( + -1, self.hidden_size + ) # (num_tokens, hidden_size) - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + next_states = torch.zeros_like( + hidden_states, dtype=hidden_states.dtype, device=hidden_states.device + ) for expert_index, expert in enumerate(self.experts): - next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1) + next_states += expert(hidden_states) * routing_weights.T[ + expert_index + ].unsqueeze(-1) next_states = next_states.reshape(original_shape) return next_states - + + def replace_gpt_oss(config: GptOssConfig, module: GptOssExpert): return GptOssExpertsLinear(module) @@ -160,11 +227,16 @@ def test_correctness(): with torch.no_grad(): original = GptOssExperts(config) - for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]: + for name in [ + "gate_up_proj", + "gate_up_proj_bias", + "down_proj", + "down_proj_bias", + ]: setattr(original, name, getattr(original, name).normal_()) original.eval() - assert original.training == False + assert not original.training true_output = original(input, routing_weights=routing_weights) linear = GptOssExpertsLinear(original) @@ -178,4 +250,4 @@ def test_correctness(): if __name__ == "__main__": - test_restore() \ No newline at end of file + test_restore() diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 609880c89..86cd66cb0 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,16 +1,17 @@ import contextlib + import tqdm -from compressed_tensors.utils import replace_module, delete_offload_module, register_offload_module, get_offloaded_device +from accelerate.big_modeling import attach_align_device_hook_on_blocks +from accelerate.hooks import AlignDevicesHook, named_module_tensors +from compressed_tensors.utils import replace_module from compressed_tensors.utils.offload import offload_to_weights_map from transformers import PreTrainedModel from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 +from llmcompressor.modeling.gpt_oss import replace_gpt_oss from llmcompressor.modeling.llama4 import replace as replace_llama4 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.modeling.gpt_oss import GptOssExpertsLinear, replace_gpt_oss from llmcompressor.utils.helpers import patch_attr -from accelerate.hooks import add_hook_to_module, remove_hook_from_module, AlignDevicesHook, named_module_tensors, set_module_tensor_to_device, PrefixedDataset -from accelerate.big_modeling import attach_align_device_hook_on_blocks __all__ = ["replace_modules_for_calibration"] @@ -50,11 +51,12 @@ def update_qwen3_moe(model: PreTrainedModel, stack): ) ) + def update_gpt_oss(model: PreTrainedModel, stack): @contextlib.contextmanager def replace(mod_name, module, name, original): hook: AlignDevicesHook = original._hf_hook - + replacement = replace_gpt_oss(model.config, original) replace_offload_module(module, name, hook, replacement) del original @@ -64,18 +66,15 @@ def replace(mod_name, module, name, original): restored = replacement.to_original() delattr(module, name) module.register_module(name, restored) - 
#replace_offload_module(module, name, hook, restored) + # replace_offload_module(module, name, hook, restored) del replacement - modules = list(model.named_modules()) for name, module in tqdm.tqdm(modules, desc="Converting modules"): for child_name, child in list(module.named_children()): if child.__class__.__name__ == "GptOssExperts": stack.enter_context(replace(name, module, child_name, child)) - - moe_context = { "Qwen3MoeForCausalLM": update_qwen3_moe, @@ -91,8 +90,6 @@ def moe_calibration_context(model: PreTrainedModel, stack): moe_context.get(cls_name)(model, stack) - - def replace_offload_module(base, name: str, hook: AlignDevicesHook, module): delattr(base, name) @@ -123,47 +120,3 @@ def replace_offload_module(base, name: str, hook: AlignDevicesHook, module): ) base.register_module(name, module) - - - # # offloading kwargs for submodule - # place_submodules = False - # offload_buffers = True - - # # copy device offloading arguments from parent - # current_device = next(base.parameters()).device # assume base has parameters - # offload_device = get_offloaded_device(base) - - # # offload parameters to weights map - # for param_name, param in named_module_tensors( - # module, include_buffers=offload_buffers, recurse=place_submodules - # ): - # offloaded = param.to(offload_device) - # if hook.tied_params_map is not None: - # hook.tied_params_map[offloaded.data_ptr()] = {} # (1) - # offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded) - - # # if the parent places submodules, offload here - # if hook.place_submodules: - # set_module_tensor_to_device(module, param_name, current_device) - - # if not hook.place_submodules: - # weights_map = PrefixedDataset( - # hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}." 
- # ) - - # submodule_hook = AlignDevicesHook( - # execution_device=hook.execution_device, - # offload=hook.offload, - # io_same_device=False, - # weights_map=weights_map, - # offload_buffers=offload_buffers, - # place_submodules=place_submodules, - # skip_keys=None, - # tied_params_map=hook.tied_params_map, - # ) - # add_hook_to_module(module, submodule_hook) - - # base.register_module(name, module) - # for c_name, child in list(module.named_children()): - # register_offload_module(module, c_name, child) - # replace_offload_module(module, None, c_name, child, child) \ No newline at end of file diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index bdc260003..901283252 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -80,33 +80,33 @@ def __call__( if dataset_args.calibrate_moe_context: moe_calibration_context(model, stack) - # # prepare intermediates cache - # activations = IntermediatesCache.from_dataloader(dataloader, model_device) - - # for subgraph_index, subgraph in enumerate(subgraphs): - # # prepare tqdm description texts - # calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" - # prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" - - # # reduce memory movement by keeping modules onloaded - # with disable_offloading(): - # # do a preliminary pass to trigger modifier hooks - # for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): - # inputs = activations.fetch(batch_idx, subgraph.input_names) - # subgraph.forward(model, **inputs) - - # LifecycleCallbacks.sequential_epoch_end() - - # # this pass does not trigger modifier hooks - # # and is only used for capturing outputs of newly compressed modules - # with HooksMixin.disable_hooks(): - # for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc): - # inputs = activations.fetch(batch_idx, subgraph.input_names) - # output = subgraph.forward(model, **inputs) - - # if subgraph_index < num_subgraphs - 1: - # activations.update(batch_idx, output) - # activations.delete(batch_idx, subgraph.consumed_names) + # prepare intermediates cache + activations = IntermediatesCache.from_dataloader(dataloader, model_device) + + for subgraph_index, subgraph in enumerate(subgraphs): + # prepare tqdm description texts + calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" + prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" + + # reduce memory movement by keeping modules onloaded + with disable_offloading(): + # do a preliminary pass to trigger modifier hooks + for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): + inputs = activations.fetch(batch_idx, subgraph.input_names) + subgraph.forward(model, **inputs) + + LifecycleCallbacks.sequential_epoch_end() + + # this pass does not trigger modifier hooks + # and is only used for capturing outputs of newly compressed modules + with HooksMixin.disable_hooks(): + for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc): + inputs = activations.fetch(batch_idx, subgraph.input_names) + output = subgraph.forward(model, **inputs) + + if subgraph_index < num_subgraphs - 1: + activations.update(batch_idx, output) + activations.delete(batch_idx, subgraph.consumed_names) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end() From 5abb3acaad4751412c0bccb6aa820ea3bd55aa75 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 7 Aug 2025 04:55:11 +0000 
Subject: [PATCH 3/3] add todo Signed-off-by: Kyle Sayers --- examples/quantization_w4a16/gpt_oss_example.py | 2 +- examples/quantization_w8a8_fp8/fp8_block_example.py | 4 +--- src/llmcompressor/modeling/gpt_oss.py | 2 ++ 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/quantization_w4a16/gpt_oss_example.py b/examples/quantization_w4a16/gpt_oss_example.py index 142f0fbb6..57009cedd 100644 --- a/examples/quantization_w4a16/gpt_oss_example.py +++ b/examples/quantization_w4a16/gpt_oss_example.py @@ -16,7 +16,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 1 # 512 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py index b5d6ca1f9..68f13cf93 100644 --- a/examples/quantization_w8a8_fp8/fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/fp8_block_example.py @@ -15,9 +15,7 @@ # In this case, we: # * quantize the weights to fp8 with per channel via ptq # * quantize the activations to fp8 with dynamic per token -recipe = QuantizationModifier( - targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"] -) +recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]) # Apply quantization. oneshot(model=model, recipe=recipe) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index a8044a3b4..8c28ac41d 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -158,6 +158,8 @@ def to_original(self) -> GptOssExperts: experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T) experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data) + # TODO: convert qparams as well + print("converted, for some reason slows down over time") import time
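
For reference, the expert linearization in these patches hinges on the layout of the fused gate_up_proj tensor: gate columns sit at even indices and up columns at odd indices along the last dimension, which is why GptOssExpertsLinear.load_weights slices with [expert_index, ..., ::2].T and [expert_index, ..., 1::2].T. Below is a minimal standalone sketch of that split, assuming the (num_experts, hidden_size, 2 * expert_dim) shape implied by that indexing; it is illustrative only and not part of the patches.

import torch

# Shapes are inferred from the slicing in load_weights (an assumption for
# illustration): gate_up_proj is (num_experts, hidden_size, 2 * expert_dim),
# with gate columns at even indices and up columns at odd indices.
num_experts, hidden_size, expert_dim = 3, 7, 5
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * expert_dim)
gate_up_proj_bias = torch.randn(num_experts, 2 * expert_dim)

expert_index = 0
gate_weight = gate_up_proj[expert_index, ..., ::2].T    # (expert_dim, hidden_size)
up_weight = gate_up_proj[expert_index, ..., 1::2].T     # (expert_dim, hidden_size)
gate_bias = gate_up_proj_bias[expert_index, ..., ::2]   # (expert_dim,)
up_bias = gate_up_proj_bias[expert_index, ..., 1::2]    # (expert_dim,)

# torch.nn.Linear stores weight as (out_features, in_features), hence the transpose.
gate_proj = torch.nn.Linear(hidden_size, expert_dim, bias=True)
up_proj = torch.nn.Linear(hidden_size, expert_dim, bias=True)
with torch.no_grad():
    gate_proj.weight.copy_(gate_weight)
    gate_proj.bias.copy_(gate_bias)
    up_proj.weight.copy_(up_weight)
    up_proj.bias.copy_(up_bias)

to_original reverses the copy direction, writing each expert's Linear weights back into the even and odd columns of the fused tensors.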