
Commit dc063de

kylesayrs and dsikka authored
Disable kernels during calibration (and tracing) (#1454)
## Purpose ##
* Guarantee that module hooks trigger by disabling kernel acceleration

## Background ##
* As of `transformers>=4.52`, model forward functions may be overwritten with kernels
* Kernel execution can be disabled using the config ([source](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/hub_kernels.py#L84))
* It seems that HF wants to continue enabling regular execution through this disabling feature. This gives us some faith that they will not [override the cls init function](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/hub_kernels.py#L78) in a way that inhibits execution with the standard forward definition
* This only affects users who have the HF `kernels` library installed

## Changes ##
* Implement the `disable_hf_kernels` context manager and add it to `calibration_forward_context`
* Remove the parentheses around the context managers in `calibration_forward_context`'s `with` statement in order to support Python 3.9
* Apply style fixes to tests

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 7c22d86 commit dc063de
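For context, here is a minimal usage sketch (not part of this commit) of how a calibration loop might wrap forward passes in `calibration_forward_context` so that gradients, the KV cache, train mode, and now HF hub kernels are all disabled while hooks observe activations. The `run_calibration` helper, `model`, and `dataloader` names are illustrative assumptions, not APIs from this repository.

```python
# Hypothetical usage sketch: wrap calibration forward passes so that
# hooks added by LLM Compressor are guaranteed to fire (gradients, the
# KV cache, train mode, and hf hub kernels are all disabled inside).
from llmcompressor.utils.helpers import calibration_forward_context


def run_calibration(model, dataloader):
    # `model` is a transformers PreTrainedModel; `dataloader` yields
    # pre-tokenized batches (both names are illustrative).
    with calibration_forward_context(model):
        for batch in dataloader:
            model(**batch)  # module hooks observe activations here
```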

2 files changed: +35 −23 lines changed


src/llmcompressor/utils/helpers.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -66,6 +66,7 @@
     "eval_context",
     "calibration_forward_context",
     "patch_attr",
+    "disable_hf_kernels",
 ]


@@ -1024,6 +1025,9 @@ def DisableQuantization(module: torch.nn.Module):

 @contextlib.contextmanager
 def eval_context(module: torch.nn.Module):
+    """
+    Disable pytorch training mode for the given module
+    """
     restore_value = module.training
     try:
         module.train(False)  # equivalent to eval()
@@ -1033,6 +1037,21 @@ def eval_context(module: torch.nn.Module):
         module.train(restore_value)


+@contextlib.contextmanager
+def disable_hf_kernels(model: PreTrainedModel):
+    """
+    In transformers>=4.50.0, some module forward methods may be
+    replaced by calls to hf hub kernels. This has the potential
+    to bypass hooks added by LLM Compressor
+    """
+    if hasattr(model, "config"):
+        with patch_attr(model.config, "disable_custom_kernels", True):
+            yield
+
+    else:
+        yield
+
+
 @contextlib.contextmanager
 def calibration_forward_context(model: PreTrainedModel):
     """
@@ -1041,12 +1060,11 @@ def calibration_forward_context(model: PreTrainedModel):
     - Remove gradient calculations
     - Disable the KV cache
     - Disable train mode and enable eval mode
+    - Disable hf kernels which could bypass hooks
     """
-    with (
-        torch.no_grad(),
-        DisableKVCache(model),
-        eval_context(model),
-    ):
+    with torch.no_grad(), DisableKVCache(model), eval_context(
+        model
+    ), disable_hf_kernels(model):
         yield

```

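As a side note, `disable_hf_kernels` relies on `patch_attr` to temporarily set `config.disable_custom_kernels` and restore the previous state on exit. Below is a minimal sketch of what a `patch_attr`-style helper does; this is assumed behavior for illustration, not the actual implementation in `helpers.py`.

```python
# Sketch of a patch_attr-style context manager (assumed behavior):
# temporarily set an attribute, then restore (or remove) it on exit.
import contextlib


@contextlib.contextmanager
def patch_attr_sketch(obj, attr, value):
    missing = object()  # sentinel for "attribute did not exist"
    original = getattr(obj, attr, missing)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        if original is missing:
            delattr(obj, attr)
        else:
            setattr(obj, attr, original)
```

With a helper like this, `disable_hf_kernels(model)` amounts to patching `model.config.disable_custom_kernels` to `True` for the duration of calibration, which per the hub_kernels integration linked above is how kernel execution gets disabled.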
tests/examples/test_quantization_2of4_sparse_w4a16.py

Lines changed: 12 additions & 18 deletions
```diff
@@ -36,9 +36,9 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
         readme = ReadMe(readme_path)

         command = readme.get_code_block_content(position=2, lang="shell")
-        assert command.startswith("python"), (
-            "Expected shell command to start with 'python'"
-        )
+        assert command.startswith(
+            "python"
+        ), "Expected shell command to start with 'python'"

         command = shlex.split(command)
         result = copy_and_run_command(tmp_path, example_dir, command)
@@ -62,18 +62,16 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
         }

         for stage, stage_info in stages.items():
-            stage_path = (
-                tmp_path / example_dir / output_dir / stage_info["path"]
-            )
+            stage_path = tmp_path / example_dir / output_dir / stage_info["path"]
             recipe_path = stage_path / "recipe.yaml"
             config_path = stage_path / "config.json"

-            assert recipe_path.exists(), (
-                f"Missing recipe file in {stage}: {recipe_path}"
-            )
-            assert config_path.exists(), (
-                f"Missing config file in {stage}: {config_path}"
-            )
+            assert (
+                recipe_path.exists()
+            ), f"Missing recipe file in {stage}: {recipe_path}"
+            assert (
+                config_path.exists()
+            ), f"Missing config file in {stage}: {config_path}"

             config = AutoConfig.from_pretrained(stage_path)
             assert config is not None, f"Failed to load config in {stage}"
@@ -82,13 +80,9 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
             if stage == "quantization":
                 actual_format = quant_config.get("format")
             else:
-                actual_format = quant_config.get(
-                    "sparsity_config", {}
-                ).get("format")
+                actual_format = quant_config.get("sparsity_config", {}).get("format")

-            assert actual_format, (
-                f"Missing expected format field in {stage} config"
-            )
+            assert actual_format, f"Missing expected format field in {stage} config"
             assert actual_format == stage_info["format"], (
                 f"Unexpected format in {stage}: got '{actual_format}', "
                 f"expected '{stage_info['format']}'"
```
