
Commit 9417d5c

[Tracing] Code AutoWrapper (#1411)
## Purpose ##
* Reduce model support burden by automatically wrapping untraceable code
* This is a programmatic implementation of all of the rules specified by the [tracing guide](https://github.com/vllm-project/llm-compressor/blob/0.5.1/src/llmcompressor/transformers/tracing/GUIDE.md)
* Remove traceable definitions for `Idefics3ForConditionalGeneration`, `LlavaForConditionalGeneration`, `MllamaForConditionalGeneration`, `Qwen2_5_VLForConditionalGeneration`, `Qwen2VLForConditionalGeneration`, and `Gemma3ForConditionalGeneration` (all of them)

## Fixes ##
* #1457

## Autowrap Patterns ##
These patterns match syntax which is untraceable and unlikely to call sequential targets (either directly or indirectly).

<details><summary>If statements whose conditions cannot be statically evaluated</summary>

This if statement can be statically evaluated, since its value can be evaluated in the context of `{"self": LlamaModel(...)}`

```python3
if self.config._attn_implementation != "eager":
    ...
```

If the statement cannot be statically evaluated, then it is wrapped

```python3
@torch.fx.wrap
def wrapped(input_ids, inputs_embeds):
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

def forward(...):
    wrapped(input_ids, inputs_embeds)  # no names are returned by this wrapper
```
</details>

<details><summary>Ignored functions (`_update_causal_mask`)</summary>

Any function or method names listed in the ignore list will be automatically wrapped

```python3
@torch.fx.wrap
def wrapped(attention_mask, inputs_embeds, cache_position, ...):
    return self._update_causal_mask(attention_mask, inputs_embeds, cache_position, ...)

def forward(...):
    causal_mask = wrapped(attention_mask, inputs_embeds, cache_position, ...)
```
</details>

<details><summary>Starred tuple unpacking</summary>

Any use of iterated unpacking will be automatically wrapped

```python3
@torch.fx.wrap
def wrapped(input_shape):
    return (*input_shape, -1, self.head_dim)

def forward(...):
    hidden_shape = wrapped(input_shape)
```
</details>

<details><summary>Starred argument unpacking</summary>

Any use of iterated unpacking into variadic args is automatically wrapped

```python3
@torch.fx.wrap
def wrapped(attn_output, input_shape):
    return attn_output.reshape(*input_shape, -1)

def forward(...):
    attn_output = wrapped(attn_output, input_shape)
```
</details>

## Autowrap Implementation Details ##
<details><summary>Wrapper arguments</summary>

Autowrapping a piece of code requires determining which variable names are used by that code and which variable names are produced by that code. This is done using the `NameAnalyzer`, which determines the unbound, assigned, and conditionally assigned names for a given piece of code.

```python3
# unbound := names which are read by node before being assigned
# assigned := names which are assigned by operations in node
# cond_assigned := names which may be assigned depending on execution
analyzer = NameAnalyzer(omit=self.namespace.keys())
unbound, assigned, conditionally_assigned = analyzer.analyze(node)
```
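As a rough, self-contained illustration of this kind of analysis (a simplified sketch, not the actual `NameAnalyzer` implementation), unbound and assigned names can be approximated with Python's `ast` module:

```python3
# Hedged sketch: approximate the "unbound" and "assigned" name sets for a code
# snippet. The real NameAnalyzer also handles conditional assignment and an
# `omit` namespace, which this toy version ignores.
import ast


def analyze_names(source: str):
    loaded, assigned = set(), set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Store):
                assigned.add(node.id)
            elif isinstance(node.ctx, ast.Load):
                loaded.add(node.id)
    # names which are read but never assigned anywhere in the snippet
    unbound = loaded - assigned
    return unbound, assigned


unbound, assigned = analyze_names("hidden_shape = (*input_shape, -1, head_dim)")
print(sorted(unbound))   # ['head_dim', 'input_shape']
print(sorted(assigned))  # ['hidden_shape']
```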
This information is then used to determine what the args, kwargs, and return names should be for the wrapping function.

```python3
# args := names which already existed and are needed for ops or wrapped return
# kwargs := names which are needed for return but did not already exist
# returns := names which are assigned or could be assigned
args = (unbound | conditionally_assigned) & self._local_names
kwargs = conditionally_assigned - self._local_names
returns = assigned | conditionally_assigned
```
</details>

<details><summary>Wrapping methods</summary>

Some untraceable code references `self` during execution. While `self` would normally be an argument to the wrapped function, `self` is usually a `torch.nn.Module`, which is not a handled type that can be passed around the graph. Instead, we treat `self` as a variable in the compiled python module namespace, and this namespace is automatically captured and executed by `torch.fx._symbolic_trace`.
</details>

<details><summary>Unwrappable code</summary>

Some code cannot be wrapped because it contains control flow statements which must exist in a certain context. For example, we cannot wrap code that contains a `continue` without also wrapping the for loop that surrounds it.

```python3
for index, layer in enumerate(self.layers):
    # ---- cannot autowrap ----
    if index <= 10:
        continue
    # ---- cannot autowrap ----

    hidden_states = layer(hidden_states)
```
</details>

## Future Extensions / Improvements ##
<details><summary>Sequentially executing vision towers</summary>

Sequentially tracing vision towers is a lower priority, as the vision towers typically have fewer parameters and aren't quantization targets. However, in the event that they do become quantization targets, or memory(vision_tower + one target) > memory(one gpu), then the vision tower layers will need to be split up. Some changes may be required to support this.

Conditionally executing the vision tower is a very common pattern:

```python3
def forward(pixel_values, image_embeds, ...):
    if image_embeds is None:
        image_embeds = self.vision_tower(pixel_values)
    ...
```

Some approaches might be
1. Allowing names like `image_embeds` to be evaluated based on the sample input being passed
2. Pattern matching against `self.{module_name}()`, where module_name is determined to be a module through evaluation (a rough sketch of this idea follows this block)
3. Using type hinting analysis tools like `jedi` to track the types of all names, and to check if any names whose type is a module are called
</details>
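As a rough illustration of approach 2 above (a simplified sketch, not code from this PR), a chunk of forward source can be checked for calls to the model's direct submodules via `self.<name>(...)`:

```python3
# Hedged sketch of approach 2: detect `self.<submodule>(...)` calls in source
# code by matching AST call nodes against the module's named children.
import ast

import torch


def called_submodules(source: str, module: torch.nn.Module) -> set:
    submodule_names = {name for name, _ in module.named_children()}
    called = set()
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "self"
            and node.func.attr in submodule_names
        ):
            called.add(node.func.attr)
    return called


class ToyVLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = torch.nn.Linear(4, 4)
        self.language_model = torch.nn.Linear(4, 4)


source = "image_embeds = self.vision_tower(pixel_values)"
print(called_submodules(source, ToyVLM()))  # {'vision_tower'}
```

Under this heuristic, a code chunk that calls no submodules (and hence no sequential targets) would be a candidate for autowrapping.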
<details><summary>Towards perfect autowrapping</summary>

As mentioned in "Sequentially executing vision towers", it may be possible to use type hinting analysis tools like `jedi` or `pytype` to infer whether any given code chunk calls a sequential target or target ancestor. If this can be done reliably (which will require extensions such as analysis of called functions), then all code that does not call sequential targets can be wrapped.
</details>

<details><summary>Towards removing tracing</summary>

If we can reliably determine whether a given code chunk calls any sequential targets, it may be possible to define each autowrapped function as its own subgraph directly, without the need for tracing. However, this will require inference of model execution from the static code, which means unrolling loops, expanding any function/method calls, and resolving dynamic model structure (in the case of llava and idefics), which may be tricky.

For example, if you want to determine the execution of the LlamaDecoder layers of a llava model like pixtral, you'd need to evaluate `self.language_model`, then analyze the source of the called module's forward function, which is stored in a separate file.

```python3
def forward(...):
    self.language_model(...)  # the type of self.language_model is determined from the config
```

Another point of evaluation would be evaluating any iteration over ModuleLists

```python3
def forward(...):
    for decoder_layer in self.layers:  # ModuleList isn't well typed and may contain different types
        decoder_layer(...)
```

The tracing system could be replaced with static code inference; both are different approaches to the same problem of determining model execution.
</details>

## Testing ##
* Able to trace all models in `tests/llmcompressor/transformers/tracing/test_models.py` without requiring traceable definitions
* Verified sequentially executed outputs are correct for `LlamaForCausalLM`
* Ran `examples/quantization_w4a16/llama3_example.py` to completion

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Domenic Barbuzzi <[email protected]>
Signed-off-by: Rahul Tuli <[email protected]>
Signed-off-by: Kelvin Cheng <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Domenic Barbuzzi <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
Co-authored-by: Kel <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Vedant <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
1 parent 9439f18 commit 9417d5c

35 files changed, +846 -8373 lines changed

examples/multimodal_audio/README.md

Lines changed: 0 additions & 6 deletions
```diff
@@ -47,12 +47,6 @@ Sequential targets are the modules which determine the granularity of error prop
 
 Choosing sequential targets with higher granularity (for example "Linear" instead of "LlamaDecoderLayer") will result in fewer hessians being allocated at the same time, decreasing the memory requirements for compression. This may also increase the recovered accuracy of the model, as compression error is propagated at a higher granularity. However, using higher granularity sequential targets may also increase compression time, as more time is spent offloading and onloading activations.
 
-### Ignore ###
-If your model is not traceable for your desired dataset, first consider adding any problematic modules to the ignore list. Doing this prevents the model tracer from tracing the internals of those modules, thereby avoid the untraceable operations.
-
-## Tracing Errors ##
-Because the architectures of audio-language models is often times more complex than those of typical decoder-only text models, you may encounter `torch.fx.TraceError`s when attempting to quantize your model. For more information on `torch.fx.TraceError`s, why they occur, and how to resolve them, please see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
-
 ## Adding Your Own Smoothquant Mappings ##
 For a guide on adding smoothquant mappings for your dataset, see the [SmoothQuant Guide](/src/llmcompressor/modifiers/smoothquant/README.md).
 
```

examples/multimodal_vision/README.md

Lines changed: 0 additions & 6 deletions
```diff
@@ -51,12 +51,6 @@ Sequential targets are the modules which determine the granularity of error prop
 
 Choosing sequential targets with higher granularity (for example "Linear" instead of "LlamaDecoderLayer") will result in fewer hessians being allocated at the same time, decreasing the memory requirements for compression. This may also increase the recovered accuracy of the model, as compression error is propagated at a higher granularity. However, using higher granularity sequential targets may also increase compression time, as more time is spent offloading and onloading activations.
 
-### Ignore ###
-If your model is not traceable for your desired dataset, first consider adding any problematic modules to the ignore list. Doing this prevents the model tracer from tracing the internals of those modules, thereby avoid the untraceable operations.
-
-## Tracing Errors ##
-Because the architectures of vision-language models is often times more complex than those of typical decoder-only text models, you may encounter `torch.fx.TraceError`s when attempting to quantize your model. For more information on `torch.fx.TraceError`s, why they occur, and how to resolve them, please see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
-
 ## Adding Your Own Smoothquant Mappings ##
 For a guide on adding smoothquant mappings for your dataset, see the [SmoothQuant Guide](/src/llmcompressor/modifiers/smoothquant/README.md).
 
```

examples/multimodal_vision/gemma3_example.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableGemma3ForConditionalGeneration
 
 # Load model.
 model_id = "google/gemma-3-4b-it"
-model = TraceableGemma3ForConditionalGeneration.from_pretrained(
+model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```

examples/multimodal_vision/idefics3_example.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -2,15 +2,14 @@
 import torch
 from datasets import load_dataset
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Idefics3ForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration
 
 # Load model.
 model_id = "HuggingFaceM4/Idefics3-8B-Llama3"  # or "HuggingFaceTB/SmolVLM-Instruct"
-model = TraceableIdefics3ForConditionalGeneration.from_pretrained(
+model = Idefics3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```

examples/multimodal_vision/llava_example.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, LlavaForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
 
 # Load model.
 model_id = "llava-hf/llava-1.5-7b-hf"
-model = TraceableLlavaForConditionalGeneration.from_pretrained(
+model = LlavaForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```

examples/multimodal_vision/mllama_example.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, MllamaForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration
 
 # Load model.
 model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-model = TraceableMllamaForConditionalGeneration.from_pretrained(
+model = MllamaForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```

examples/multimodal_vision/pixtral_example.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, LlavaForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
 
 # Load model.
 model_id = "mgoin/pixtral-12b"
-model = TraceableLlavaForConditionalGeneration.from_pretrained(
+model = LlavaForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```

examples/multimodal_vision/qwen2_vl_example.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -4,15 +4,14 @@
 import torch
 from datasets import load_dataset
 from qwen_vl_utils import process_vision_info
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration
 
 # Load model.
 model_id = "Qwen/Qwen2-VL-2B-Instruct"
-model = TraceableQwen2VLForConditionalGeneration.from_pretrained(
+model = Qwen2VLForConditionalGeneration.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype="auto",
```

examples/multimodal_vision/qwen_2_5_vl_example.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -4,17 +4,14 @@
 import torch
 from datasets import load_dataset
 from qwen_vl_utils import process_vision_info
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
-from llmcompressor.transformers.tracing import (
-    TraceableQwen2_5_VLForConditionalGeneration,
-)
 
 # Load model.
 model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
-model = TraceableQwen2_5_VLForConditionalGeneration.from_pretrained(
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype="auto",
```

src/llmcompressor/args/dataset_arguments.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -179,3 +179,10 @@ class DatasetArguments(CustomDatasetArguments):
             "independent]"
         },
     )
+    tracing_ignore: List[str] = field(
+        default_factory=lambda: ["_update_causal_mask"],
+        metadata={
+            "help": "List of functions to ignore during tracing, either "
+            "{module}.{method_name} or {function_name}"
+        },
+    )
```
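For reference, a hedged usage sketch of the new argument (the field name and default come from the diff above; the import path is assumed from the file location, and the second list entry is purely hypothetical):

```python3
# Hedged sketch: extending the default tracing ignore list.
# Import path assumed to mirror src/llmcompressor/args/dataset_arguments.py;
# "MyVisionModel._merge_image_features" is a hypothetical {module}.{method_name} entry.
from llmcompressor.args.dataset_arguments import DatasetArguments

dataset_args = DatasetArguments(
    tracing_ignore=["_update_causal_mask", "MyVisionModel._merge_image_features"],
)
print(dataset_args.tracing_ignore)
```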
