
Commit e628e13

[Tracing] Onload sequential ancestors (#1690)
## Background ##

The idea of tracing is to capture all of the events and operations that happen during a model's execution. However, some operations are too complex to capture in the graph. One of these skipped operations is the onloading/offloading of weights, because this operation is too complicated to be captured by torch dynamo.

```
     parent
     |     \
seq_target  weight
```

The problem is that, if we skip onloading/offloading during graph capture, weights which appear in the graph might not be onloaded at execution time. This isn't a problem for weights at an equal or higher granularity than the sequential targets, since those weights belong to modules which are called by the graph (and therefore trigger onloading hooks).

```
hidden_states = ...
... = self.decoder_layer(hidden_states)  # module call triggers onloading hooks
```

```
hidden_states = ...
hidden_states = hidden_states + self.bias  # self.bias might not be onloaded;
                                           # we never captured the onloading hook in the graph
```

### Big Code ###

The GPT Big Code model is the first model we've seen like this, where a parameter is directly attached to a module which is a caller (ancestor) of the sequential targets. It has this extra, [directly attached bias parameter](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L416-L418). A sketch of how such ancestors can be identified is shown below, after this message.

```
  GPTBigCodeModel
     |        \
GPTBigCodeBlock  bias
```

## Purpose ##

* Support models which have call graph ancestors with directly attached parameters

## Changes ##

* Change the policy for these kinds of modules: instead of skipping the tracing of their internals, simply onload them, since they are unlikely to have a large number of direct parameters.

## Testing ##

* Added Big Code to the tracing tests
* Ran the Big Code example script to completion

<details><summary>example_script.py</summary>

```python3
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Load model
model_stub = "ibm-granite/granite-20b-code-instruct-8k"
model_name = model_stub.split("/")[-1]

num_samples = 1024
max_seq_len = 8192

tokenizer = AutoTokenizer.from_pretrained(model_stub)
model = AutoModelForCausalLM.from_pretrained(
    model_stub,
    device_map="auto",
    torch_dtype="auto",
)


def preprocess_fn(example):
    return {"text": example["text"]}


ds = load_dataset("neuralmagic/LLM_compression_calibration", split="train")
ds = ds.map(preprocess_fn)


def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        truncation=False,
        add_special_tokens=True,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme
recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["lm_head"],
    sequential_targets=["GPTBigCodeBlock"],
    dampening_frac=0.01,
)

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=max_seq_len,
    num_calibration_samples=num_samples,
)

# Save to disk in compressed-tensors format
save_path = model_name + "-quantized.w4a16"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"Model and tokenizer saved to: {save_path}")
```

</details>

---------

Signed-off-by: Kyle Sayers <[email protected]>
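For illustration only (not part of this commit), here is a minimal sketch of how call graph ancestors with directly attached parameters might be detected. The helper name is hypothetical, and the module tree is used as a stand-in for the traced call graph:

```python
from typing import Set

import torch


def ancestors_with_direct_params(
    model: torch.nn.Module, sequential_targets: Set[str]
) -> Set[torch.nn.Module]:
    """Hypothetical helper: find modules that contain a sequential target
    somewhere below them (ancestors) and also own parameters directly,
    like GPTBigCodeModel and its bias."""
    ancestors = set()
    for module in model.modules():
        # does this module contain a sequential target somewhere below it?
        contains_target = any(
            submodule.__class__.__name__ in sequential_targets
            for submodule in module.modules()
            if submodule is not module
        )
        # recurse=False only yields parameters registered on this module itself
        has_direct_params = next(module.parameters(recurse=False), None) is not None
        if contains_target and has_direct_params:
            ancestors.add(module)
    return ancestors
```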
1 parent dddc2f8 commit e628e13

File tree

3 files changed: +17 −19 lines

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 3 additions & 12 deletions
```diff
@@ -133,18 +133,9 @@ class AWQMapping:
 ]
 
 _bloom_mappings = [
-    AWQMapping(
-        "re:.*input_layernorm$",
-        ["re:.*query_key_value$"]
-    ),
-    AWQMapping(
-        "re:.*post_attention_layernorm$",
-        ["re:.*dense_h_to_4h$"]
-    ),
-    AWQMapping(
-        "re:.*gelu_impl$",
-        ["re:.*dense_4h_to_h$"]
-    ),
+    AWQMapping("re:.*input_layernorm$", ["re:.*query_key_value$"]),
+    AWQMapping("re:.*post_attention_layernorm$", ["re:.*dense_h_to_4h$"]),
+    AWQMapping("re:.*gelu_impl$", ["re:.*dense_4h_to_h$"]),
     # Note: AutoAWQ excludes this mapping, based on researcher's post in
     # https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
     # AWQMapping(
```

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -5,6 +5,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
+from accelerate.hooks import remove_hook_from_module
 from compressed_tensors.quantization import find_name_or_class_matches
 from compressed_tensors.utils import (
     has_offloaded_params,
@@ -177,13 +178,12 @@ def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
 
         # check unlikely case that ancestors have direct params which are offloaded
         offloaded_ancestors = offloaded & ancestors
-        if offloaded_ancestors:
-            names = set(module.__class__.__name__ for module in offloaded_ancestors)
+        for ancestor in offloaded_ancestors:
+            remove_hook_from_module(ancestor, recurse=False)
+            self.offloaded.remove(ancestor)
             logger.warning(
-                "The following modules are call graph ancestors of sequential targets,"
-                f"but also contain offloaded modules: {names}.\n"
-                "These modules will not be traced, and any sequential target children "
-                "will be executed jointly, which may lead to OOM errors"
+                f"Direct parameters attached to {ancestor.__class__.__name__} have "
+                "been onloaded in order to ensure safe graph capture and execution"
             )
 
     def create_arg(self, a: Any) -> Argument:
```
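For reference, a standalone sketch of the onloading policy above (assuming accelerate-style offloading hooks; the helper name and surrounding context are assumptions, not part of the codebase):

```python
from typing import Set

from accelerate.hooks import remove_hook_from_module
from loguru import logger
from torch.nn import Module


def onload_ancestor_params(ancestors: Set[Module], offloaded: Set[Module]) -> Set[Module]:
    """Hypothetical helper mirroring the change above: onload call graph
    ancestors that have offloaded direct parameters so graph capture and
    execution see real tensors."""
    for ancestor in offloaded & ancestors:
        # recurse=False leaves offloading in place for child modules (e.g. the
        # sequential targets), so only the ancestor's direct params are affected
        remove_hook_from_module(ancestor, recurse=False)
        logger.warning(
            f"Direct parameters attached to {ancestor.__class__.__name__} have "
            "been onloaded in order to ensure safe graph capture and execution"
        )
    # return the remaining offloaded modules
    return offloaded - ancestors
```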

tests/llmcompressor/transformers/tracing/test_models.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -42,7 +42,14 @@
             "text",
             [],
         ),
-        ("RedHatAI/DeepSeek-V3-BF16", AutoModelForCausalLM, None, "text", []),
+        (
+            "ibm-granite/granite-20b-code-instruct-8k",
+            AutoModelForCausalLM,
+            None,
+            "text",
+            [],
+        ),
+        ("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
         # --- vision ---
         (
             "HuggingFaceM4/Idefics3-8B-Llama3",
```
