diff --git a/examples/quantization_w4a16/gpt_oss_example.py b/examples/quantization_w4a16/gpt_oss_example.py
new file mode 100644
index 000000000..57009cedd
--- /dev/null
+++ b/examples/quantization_w4a16/gpt_oss_example.py
@@ -0,0 +1,80 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# Select model and load it.
+model_id = "unsloth/gpt-oss-20b-BF16"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# * quantize the weights to 4 bits using GPTQ with group size 128
+recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    calibrate_moe_context=True,
+    pipeline="sequential",
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save the compressed model to disk.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py
index b5d6ca1f9..68f13cf93 100644
--- a/examples/quantization_w8a8_fp8/fp8_block_example.py
+++ b/examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -15,9 +15,7 @@
 # In this case, we:
 # * quantize the weights to fp8 with per channel via ptq
 # * quantize the activations to fp8 with dynamic per token
-recipe = QuantizationModifier(
-    targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
-)
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
 
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py
new file mode 100644
index 000000000..8c28ac41d
--- /dev/null
+++ b/src/llmcompressor/modeling/gpt_oss.py
@@ -0,0 +1,255 @@
+# flake8: noqa
+from typing import List
+
+import torch
+from compressed_tensors.utils import align_module_device, update_offload_parameter
+from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
+
+from llmcompressor.utils.dev import skip_weights_initialize
+
+
+class GptOssExpert(torch.nn.Module):
+    """A single GPT-OSS expert with its gate, up, and down projections unpacked
+    into standalone `torch.nn.Linear` modules."""
+
+    gate_proj: torch.nn.Linear
+    up_proj: torch.nn.Linear
+    down_proj: torch.nn.Linear
+
+    def __init__(self, experts: GptOssExperts):
+        super().__init__()
+
+        self.hidden_size = experts.hidden_size
+        self.expert_dim = experts.expert_dim
+        self.alpha = experts.alpha
+        self.limit = experts.limit
+
+        assert experts.gate_up_proj.dtype == experts.gate_up_proj_bias.dtype
+        assert experts.down_proj.dtype == experts.down_proj_bias.dtype
+
+        with skip_weights_initialize():
+            self.gate_proj = torch.nn.Linear(
+                self.hidden_size,
+                self.expert_dim,
+                bias=True,
+                dtype=experts.gate_up_proj.dtype,
+            )
+            self.up_proj = torch.nn.Linear(
+                self.hidden_size,
+                self.expert_dim,
+                bias=True,
+                dtype=experts.gate_up_proj.dtype,
+            )
+            self.down_proj = torch.nn.Linear(
+                self.expert_dim,
+                self.hidden_size,
+                bias=True,
+                dtype=experts.down_proj.dtype,
+            )
+
+    def forward(self, hidden_states: torch.Tensor):
+        gate = self.gate_proj(hidden_states)
+        gate = gate.clamp(min=None, max=self.limit)
+
+        up = self.up_proj(hidden_states)
+        up = up.clamp(min=-self.limit, max=self.limit)
+
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        return self.down_proj((up + 1) * glu)
+
+
+class GptOssExpertsLinear(torch.nn.Module):
+    """Replacement for `GptOssExperts` which unpacks the fused 3D expert weights
+    into per-expert `GptOssExpert` modules so that Linear-targeting modifiers
+    can calibrate and quantize them."""
+
+    experts: List[GptOssExpert]
+
+    def __init__(self, experts: GptOssExperts):
+        super().__init__()
+
+        self.intermediate_size = experts.intermediate_size
+        self.num_experts = experts.num_experts
+        self.hidden_size = experts.hidden_size
+        self.expert_dim = experts.expert_dim
+
+        with skip_weights_initialize():
+            self.experts = torch.nn.ModuleList(
+                [GptOssExpert(experts) for _ in range(self.num_experts)]
+            )
+
+        self.load_weights(experts)
+
+        self.alpha = experts.alpha
+        self.limit = experts.limit
+
+    def load_weights(self, experts: GptOssExperts):
+        # gate_up_proj fuses gate and up: gate columns at even indices, up at odd
+        with align_module_device(experts):
+            for expert_index, expert in enumerate(self.experts):
+                update_offload_parameter(
+                    expert.gate_proj,
+                    "weight",
+                    experts.gate_up_proj[expert_index, ..., ::2].T,
+                )
+                update_offload_parameter(
+                    expert.gate_proj,
+                    "bias",
+                    experts.gate_up_proj_bias[expert_index, ..., ::2],
+                )
+
+                update_offload_parameter(
+                    expert.up_proj,
+                    "weight",
+                    experts.gate_up_proj[expert_index, ..., 1::2].T,
+                )
+                update_offload_parameter(
+                    expert.up_proj,
+                    "bias",
+                    experts.gate_up_proj_bias[expert_index, ..., 1::2],
+                )
+
+                update_offload_parameter(
+                    expert.down_proj, "weight", experts.down_proj[expert_index].T
+                )
+                update_offload_parameter(
+                    expert.down_proj, "bias", experts.down_proj_bias[expert_index]
+                )
+
+    def to_original(self) -> GptOssExperts:
+        # TODO: this doesn't really handle offloading or correct device placement
+        with skip_weights_initialize(use_zeros=True):
+            fake_config = GptOssConfig(
+                intermediate_size=self.intermediate_size,
+                num_local_experts=self.num_experts,
+                hidden_size=self.hidden_size,
+            )
+            experts = GptOssExperts(fake_config)
+            experts.gate_up_proj = torch.nn.Parameter(
+                experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype),
+                requires_grad=False,
+            )
+            experts.gate_up_proj_bias = torch.nn.Parameter(
+                experts.gate_up_proj_bias.to(
+                    dtype=self.experts[0].gate_proj.weight.dtype
+                ),
+                requires_grad=False,
+            )
+            experts.down_proj = torch.nn.Parameter(
+                experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype),
+                requires_grad=False,
+            )
+            experts.down_proj_bias = torch.nn.Parameter(
+                experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype),
+                requires_grad=False,
+            )
+
+        for expert_index, expert in enumerate(self.experts):
+            with align_module_device(expert.gate_proj, "cpu"), align_module_device(
+                expert.up_proj, "cpu"
+            ), align_module_device(expert.down_proj, "cpu"):
+                experts.gate_up_proj[expert_index, ..., ::2].copy_(
+                    expert.gate_proj.weight.data.T
+                )
+                experts.gate_up_proj_bias[expert_index, ..., ::2].copy_(
+                    expert.gate_proj.bias.data
+                )
+
+                experts.gate_up_proj[expert_index, ..., 1::2].copy_(
+                    expert.up_proj.weight.data.T
+                )
+                experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_(
+                    expert.up_proj.bias.data
+                )
+
+                experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T)
+                experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data)
+
+        # TODO: convert qparams as well
+        # TODO: restoring the original module appears to slow down over time
+
+        experts.eval()
+        return experts
+
+    def forward(
+        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
+    ) -> torch.Tensor:
+        """
+        When training, it is more efficient to loop over the experts and compute
+        the output for each expert separately, as otherwise memory would explode.
+
+        For inference we can sacrifice some memory and compute the output for
+        all experts at once by repeating the inputs.
+
+        Args:
+            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+            router_indices (torch.Tensor): (batch_size * token_num, top_k)
+            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+        Returns:
+            torch.Tensor
+        """
+        original_shape = hidden_states.shape
+        hidden_states = hidden_states.reshape(
+            -1, self.hidden_size
+        )  # (num_tokens, hidden_size)
+
+        next_states = torch.zeros_like(
+            hidden_states, dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        for expert_index, expert in enumerate(self.experts):
+            next_states += expert(hidden_states) * routing_weights.T[
+                expert_index
+            ].unsqueeze(-1)
+
+        next_states = next_states.reshape(original_shape)
+        return next_states
+
+
+def replace_gpt_oss(config: GptOssConfig, module: GptOssExperts):
+    return GptOssExpertsLinear(module)
+
+
+def test_restore():
+    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)
+
+    original = GptOssExperts(config)
+    linear = GptOssExpertsLinear(original)
+
+    restored = linear.to_original()
+    for param_name, param in original.named_parameters(recurse=False):
+        restored_param = getattr(restored, param_name)
+        assert param.shape == restored_param.shape
+        assert param.dtype == restored_param.dtype
+
+        assert torch.all(getattr(restored, param_name) == param)
+
+
+def test_correctness():
+    batch_size, seq_len = 13, 12
+    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)
+
+    input = torch.rand((batch_size, seq_len, config.hidden_size))
+    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))
+
+    with torch.no_grad():
+        original = GptOssExperts(config)
+        for name in [
+            "gate_up_proj",
+            "gate_up_proj_bias",
+            "down_proj",
+            "down_proj_bias",
+        ]:
+            setattr(original, name, getattr(original, name).normal_())
+
+        original.eval()
+        assert not original.training
+        true_output = original(input, routing_weights=routing_weights)
+
+        linear = GptOssExpertsLinear(original)
+        output = linear(input, routing_weights=routing_weights)
+
+        assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+
+        restored = linear.to_original()
+        restored_output = restored(input, routing_weights=routing_weights)
+        assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)
+
+
+if __name__ == "__main__":
+    test_restore()
+    test_correctness()
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index cb61f5fad..86cd66cb0 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -1,7 +1,14 @@
+import contextlib
+
+import tqdm
+from accelerate.big_modeling import attach_align_device_hook_on_blocks
+from accelerate.hooks import AlignDevicesHook, named_module_tensors
 from compressed_tensors.utils import replace_module
+from compressed_tensors.utils.offload import offload_to_weights_map
 from transformers import PreTrainedModel
 
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
+from llmcompressor.modeling.gpt_oss import replace_gpt_oss
 from llmcompressor.modeling.llama4 import replace as replace_llama4
 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
 from llmcompressor.utils.helpers import patch_attr
@@ -12,11 +19,13 @@
 replacements = {
     "DeepseekV3MoE": replace_deepseekv3,
     "Llama4TextMoe": replace_llama4,
+    "GptOssExperts": replace_gpt_oss,
 }
 
 
 def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
-    for name, module in model.named_modules():
+    modules = list(model.named_modules())
+    for name, module in tqdm.tqdm(modules, desc="Converting modules"):
         cls_name = module.__class__.__name__
         if cls_name in replacements:
             new_module = replacements[cls_name](config=model.config, module=module)
@@ -28,8 +37,9 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
 # ------------------- module replacements; during calibration --------------------
 
 
-def update_qwen3_moe(model, stack):
-    for module in model.modules():
+def update_qwen3_moe(model: PreTrainedModel, stack):
+    modules = list(model.modules())
+    for module in tqdm.tqdm(modules, desc="Converting modules"):
         cls_name = module.__class__.__name__
         if cls_name == "Qwen3MoeDecoderLayer":
             # Optionally update the model.config to pass in other arguments
@@ -42,8 +52,33 @@ def update_qwen3_moe(model, stack):
             )
 
 
+def update_gpt_oss(model: PreTrainedModel, stack):
+    @contextlib.contextmanager
+    def replace(mod_name, module, name, original):
+        hook: AlignDevicesHook = original._hf_hook
+
+        replacement = replace_gpt_oss(model.config, original)
+        replace_offload_module(module, name, hook, replacement)
+        del original
+
+        yield
+
+        restored = replacement.to_original()
+        delattr(module, name)
+        module.register_module(name, restored)
+        # replace_offload_module(module, name, hook, restored)
+        del replacement
+
+    modules = list(model.named_modules())
+    for name, module in tqdm.tqdm(modules, desc="Converting modules"):
+        for child_name, child in list(module.named_children()):
+            if child.__class__.__name__ == "GptOssExperts":
+                stack.enter_context(replace(name, module, child_name, child))
+
+
 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
+    "GptOssForCausalLM": update_gpt_oss,
 }
 
 
@@ -53,3 +88,35 @@
     cls_name = model.__class__.__name__
     if cls_name in moe_context:
         moe_context.get(cls_name)(model, stack)
+
+
+def replace_offload_module(base, name: str, hook: AlignDevicesHook, module):
+    delattr(base, name)
+
+    assert hook.offload
+    assert hook.weights_map is not None
+
+    # offload parameters to weights map
+    offload_device = "cpu"
+    for param_name, param in named_module_tensors(
+        module, include_buffers=hook.offload_buffers, recurse=True
+    ):
+        offloaded = param.to(offload_device)
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+        offload_to_weights_map(hook.weights_map, param_name, offloaded)
+
+    # attach hooks and offload weights
+    attach_align_device_hook_on_blocks(
+        module,
+        hook.execution_device,
+        hook.offload,
+        hook.weights_map,
+        hook.offload_buffers,
+        "",
+        hook.skip_keys,
+        None,
+        hook.tied_params_map,
+    )
+
+    base.register_module(name, module)
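
Note: the following is a minimal, standalone sketch (not part of the patch) of the fused-weight layout that GptOssExpertsLinear.load_weights() and to_original() rely on: within each expert's gate_up_proj, gate columns sit at even indices and up columns at odd indices. The shapes below are illustrative placeholders, not the real model dimensions.

import torch

hidden_size, expert_dim = 8, 6
gate_up_proj = torch.randn(hidden_size, 2 * expert_dim)  # fused (gate, up) weight of one expert

# decompose: even columns -> gate, odd columns -> up, transposed into nn.Linear layout
gate_weight = gate_up_proj[..., ::2].T  # (expert_dim, hidden_size)
up_weight = gate_up_proj[..., 1::2].T  # (expert_dim, hidden_size)

# recombine and confirm the round trip is lossless, mirroring to_original()
recombined = torch.empty_like(gate_up_proj)
recombined[..., ::2] = gate_weight.T
recombined[..., 1::2] = up_weight.T
assert torch.equal(recombined, gate_up_proj)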