
Commit 9765e2b

fynnsudsikka and Dipika Sikka authored
Fix negative activation values in awq scale calculation (#1788)
SUMMARY: We're currently missing an `abs()` call when computing activation means. This results in negative mean values, which then interfere with the scales calculation, causing the scales to be set to `nan` (a short illustrative sketch follows the diff below). As seen [here](https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/quantize/quantizer.py#L331), this `abs()` call existed in the AutoAWQ repo but was lost at some stage.

TEST PLAN: After the fix, applied AWQ to Qwen3-Coder-30B-A3B-Instruct and evaluated on humaneval and humaneval+.

Algorithm | humaneval@1 | humaneval@10 | humaneval+@1 | humaneval+@10 |
--------- | ----------- | ------------ | ------------ | ------------- |
base | 0.93 | 0.939 | 0.887 | 0.898 |
GPTQ | 0.927 | **0.945** | 0.885 | **0.905** |
RTN | 0.924 | 0.939 | 0.87 | 0.885 |
AWQ | **0.937** | **0.945** | **0.893** | 0.902 |

<details>
<summary>awq script</summary>

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym-main-abs2-no-duo-scaling"

# Configure the quantization algorithm to run.
recipe = [
    AWQModifier(
        duo_scaling=False,
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Select calibration dataset.
DATASET_ID = "codeparrot/self-instruct-starcoder"
DATASET_SPLIT = "curated"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        chat_messages = [
            {"role": "user", "content": example["instruction"].strip()},
            {"role": "assistant", "content": example["output"].strip()},
        ]
        tokenized_messages = tokenizer.apply_chat_template(
            chat_messages, tokenize=True
        )
        return {"input_ids": tokenized_messages}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
```

</details>

<details>
<summary>gptq script</summary>

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
SAVE_DIR = MODEL_ID.split("/")[-1] + "-gptq-sym"

# Configure the quantization algorithm to run.
recipe = [
    GPTQModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Select calibration dataset.
DATASET_ID = "codeparrot/self-instruct-starcoder"
DATASET_SPLIT = "curated"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        chat_messages = [
            {"role": "user", "content": example["instruction"].strip()},
            {"role": "assistant", "content": example["output"].strip()},
        ]
        tokenized_messages = tokenizer.apply_chat_template(
            chat_messages, tokenize=True
        )
        return {"input_ids": tokenized_messages}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
```

</details>

<details>
<summary>rtn script</summary>

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
SAVE_DIR = MODEL_ID.split("/")[-1] + "-rtn-sym"

# Configure the quantization algorithm to run.
recipe = [
    QuantizationModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Select calibration dataset.
DATASET_ID = "codeparrot/self-instruct-starcoder"
DATASET_SPLIT = "curated"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        chat_messages = [
            {"role": "user", "content": example["instruction"].strip()},
            {"role": "assistant", "content": example["output"].strip()},
        ]
        tokenized_messages = tokenizer.apply_chat_template(
            chat_messages, tokenize=True
        )
        return {"input_ids": tokenized_messages}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
```

</details>

---------

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 4caf540 commit 9765e2b

File tree

1 file changed, +1 -1 lines changed
  • src/llmcompressor/modifiers/awq


src/llmcompressor/modifiers/awq/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -397,7 +397,7 @@ def cache_smooth_activations_hook(
         ):
             self._smooth_activation_means[smooth_name] = _accumulate_mean(
                 # Assume that first argument is the input
-                args[0].cpu().detach().squeeze(),
+                args[0].cpu().abs().detach().squeeze(),
                 self._smooth_activation_means.get(smooth_name, None),
             )

```
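For intuition, here is a minimal, hypothetical sketch of the failure mode, not the llm-compressor or AutoAWQ code path: AWQ-style scales raise the per-channel activation mean to a fractional exponent, and a negative base with a fractional exponent evaluates to `nan` in floating point, which then poisons the scale search. The tensor values, `alpha`, and the simplified scale formula below are illustrative assumptions only.

```python
import torch

# Illustrative only: a simplified AWQ-style scale computation.
# `alpha` stands in for the ratio swept during the scale grid search.
alpha = 0.5
x = torch.tensor([[-0.5, 2.0], [1.5, -3.0]])   # calibration activations (example values)
w_mean = torch.tensor([0.2, 0.4])              # per-channel weight magnitude means (example values)

signed_mean = x.mean(dim=0)     # tensor([0.5, -0.5]) -- what the buggy path fed in
abs_mean = x.abs().mean(dim=0)  # tensor([1.0,  2.5]) -- what the fix computes

# A negative base raised to a fractional exponent is nan, so the scale for
# the second channel becomes nan under the signed mean.
print(signed_mean.pow(alpha) / w_mean.pow(1 - alpha))  # tensor([1.5811,    nan])
print(abs_mean.pow(alpha) / w_mean.pow(1 - alpha))     # finite scales for every channel
```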
