Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
bccc54b
Add Gemma3 model tracing support
kelkelcheng Apr 23, 2025
2877679
Refactor Gemma3 tracing imports and update function signatures for cl…
kelkelcheng Apr 23, 2025
0e322b6
[Tracing] Better runtime error messages (#1307)
kylesayrs Apr 23, 2025
b5be87f
Remove debug print statement
kelkelcheng Apr 27, 2025
0683e1e
[Tests] Fix test case; update structure (#1375)
dsikka Apr 23, 2025
7c0771b
fix: Make Recipe.model_dump() output compatible with model_validate()…
ved1beta Apr 23, 2025
11c0e9c
Add: documentation for enhanced `save_pretrained` parameters (#1377)
rahul-tuli Apr 23, 2025
353174c
Revert "fix: Make Recipe.model_dump() output compatible .... (#1378)
rahul-tuli Apr 24, 2025
1aebb7a
AWQ resolved mappings -- ensure shapes align (#1372)
brian-dellabetta Apr 24, 2025
cc665ac
Update w4a16_actorder_weight.yaml lmeval config (#1380)
dbarbuzzi Apr 24, 2025
82db147
[WIP] Add AWQ Asym e2e test case (#1374)
dsikka Apr 24, 2025
d3c0f0a
Bump version; set ct version (#1381)
dsikka Apr 25, 2025
addef4e
bugfix AWQ with Llama models and python 3.9 (#1384)
brian-dellabetta Apr 25, 2025
773fe7f
awq -- hotfix to missing kwargs (#1395)
brian-dellabetta Apr 28, 2025
11138ba
Exclude images from package (#1397)
kylesayrs Apr 29, 2025
8b9f7c4
add gemma3 example
kylesayrs Apr 29, 2025
ea481be
Merge remote-tracking branch 'origin' into kylesayrs/gemma3-example
kylesayrs Apr 29, 2025
81c5799
add back labels check
kylesayrs Apr 29, 2025
0da3931
Merge remote-tracking branch 'origin' into kylesayrs/gemma3-example
kylesayrs Apr 29, 2025
c236f32
Merge branch 'kylesayrs/gemma3-example' into kc/gemma-3-tracing-support
kelkelcheng Apr 29, 2025
94e07cc
Merge branch 'main' into kc/gemma-3-tracing-support
kylesayrs May 2, 2025
61cc2f9
Merge branch 'main' into kc/gemma-3-tracing-support
kylesayrs May 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import inspect
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union

from compressed_tensors import has_offloaded_params
from compressed_tensors.quantization import find_name_or_class_matches
from torch.fx import Graph, GraphModule, Node
from torch.fx.graph import PythonCode
from torch.fx.proxy import Argument
from torch.nn import Module
from transformers import PreTrainedModel
Expand All @@ -32,16 +33,33 @@ class Subgraph:
graph: Graph
input_names: Set[str]
consumed_names: Set[str]
_code: Optional[PythonCode] = None

def compile_forward(self) -> Callable[[Any], Any]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
"""
Generate and compile code for executing this subgraph
Execute the operations within the subgraph

:return: function which, when called, executes this subgraph
:param \\*args: argument inputs to subgraph forward function
:param \\**kwargs: keyword inputs to subgraph forward function
        :return: keyword outputs of subgraph forward function (non-consumed variables)
"""
code = self.graph.python_code("self")
exec(code.src, code.globals)
return code.globals.get("forward")
if self._code is None:
self._code = self.graph.python_code("self")
exec(self._code.src, self._code.globals)

forward_fn = self._code.globals.get("forward")

try:
outputs = forward_fn(*args, **kwargs)
except Exception as exception:
raise RuntimeError(
"Raised an exception during execution of the following code:\n"
f"```\n{add_line_numbers(self._code.src)}\n```\n"
"This is likely due to a violation of shape assumptions made when "
"tracing"
) from exception

return outputs


def trace_subgraphs(
Expand Down Expand Up @@ -376,3 +394,9 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
for name, module in model.named_modules()
if find_name_or_class_matches(name, module, target_names)
)


def add_line_numbers(text: str) -> str:
    """
    Prefix every line of a text blob with its 1-indexed line number

    :param text: text whose lines should be numbered
    :return: the same text with each line prefixed by its line number
    """
    numbered = (
        f"{number} {line}"
        for number, line in enumerate(text.splitlines(), start=1)
    )
    return "\n".join(numbered)
7 changes: 2 additions & 5 deletions src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,10 @@ def run_pipeline(
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

# compile subgraph forward function
forward_function = subgraph.compile_forward()

        # do a preliminary pass to trigger modifier hooks
for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
inputs = intermediates.fetch(batch_index, subgraph.input_names)
forward_function(model, **inputs)
subgraph.forward(model, **inputs)

# TODO: replace with a lifecycle event
if callback_modifier:
Expand All @@ -78,7 +75,7 @@ def run_pipeline(
with HooksMixin.disable_hooks():
for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
inputs = intermediates.fetch(batch_index, subgraph.input_names)
output = forward_function(model, **inputs)
output = subgraph.forward(model, **inputs)

if subgraph_index < num_subgraphs - 1:
intermediates.update(batch_index, output)
Expand Down
6 changes: 5 additions & 1 deletion src/llmcompressor/transformers/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from .gemma3 import (
Gemma3ForConditionalGeneration as TraceableGemma3ForConditionalGeneration,
)
from .llava import (
LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
)
Expand All @@ -11,12 +14,13 @@
Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration,
)
from .qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration
Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration,
)
from .debug import get_model_class

__all__ = [
"get_model_class",
"TraceableGemma3ForConditionalGeneration",
"TraceableLlavaForConditionalGeneration",
"TraceableMllamaForConditionalGeneration",
"TraceableQwen2VLForConditionalGeneration",
Expand Down
Loading