Skip to content

Commit b6a75f6

Browse files
committed
commit
Signed-off-by: Kyle Sayers <[email protected]>
1 parent dddc2f8 commit b6a75f6

File tree

3 files changed

+16
-19
lines changed

3 files changed

+16
-19
lines changed

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,9 @@ class AWQMapping:
133133
]
134134

135135
_bloom_mappings = [
136-
AWQMapping(
137-
"re:.*input_layernorm$",
138-
["re:.*query_key_value$"]
139-
),
140-
AWQMapping(
141-
"re:.*post_attention_layernorm$",
142-
["re:.*dense_h_to_4h$"]
143-
),
144-
AWQMapping(
145-
"re:.*gelu_impl$",
146-
["re:.*dense_4h_to_h$"]
147-
),
136+
AWQMapping("re:.*input_layernorm$", ["re:.*query_key_value$"]),
137+
AWQMapping("re:.*post_attention_layernorm$", ["re:.*dense_h_to_4h$"]),
138+
AWQMapping("re:.*gelu_impl$", ["re:.*dense_4h_to_h$"]),
148139
# Note: AutoAWQ excludes this mapping, based on researcher's post in
149140
# https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
150141
# AWQMapping(

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
66

77
import torch
8+
from accelerate.hooks import remove_hook_from_module
89
from compressed_tensors.quantization import find_name_or_class_matches
910
from compressed_tensors.utils import (
1011
has_offloaded_params,
@@ -177,13 +178,12 @@ def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
177178

178179
# check unlikely case that ancestors have direct params which are offloaded
179180
offloaded_ancestors = offloaded & ancestors
180-
if offloaded_ancestors:
181-
names = set(module.__class__.__name__ for module in offloaded_ancestors)
181+
for ancestor in offloaded_ancestors:
182+
remove_hook_from_module(ancestor, recurse=False)
183+
self.offloaded.remove(ancestor)
182184
logger.warning(
183-
"The following modules are call graph ancestors of sequential targets,"
184-
f"but also contain offloaded modules: {names}.\n"
185-
"These modules will not be traced, and any sequential target children "
186-
"will be executed jointly, which may lead to OOM errors"
185+
f"Direct parameters attached to {ancestor.__class__.__name__} have "
186+
"been onloaded in order to ensure safe graph capture and execution"
187187
)
188188

189189
def create_arg(self, a: Any) -> Argument:

tests/llmcompressor/transformers/tracing/test_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@
4242
"text",
4343
[],
4444
),
45-
("RedHatAI/DeepSeek-V3-BF16", AutoModelForCausalLM, None, "text", []),
45+
(
46+
"ibm-granite/granite-20b-code-instruct-8k",
47+
AutoModelForCausalLM,
48+
None,
49+
"text",
50+
[],
51+
)("RedHatAI/DeepSeek-V3-BF16", AutoModelForCausalLM, None, "text", []),
4652
# --- vision ---
4753
(
4854
"HuggingFaceM4/Idefics3-8B-Llama3",

0 commit comments

Comments
 (0)