
Commit 4f35d48

rahul-tuli and dsikka authored
Fix Multi-Context Manager Syntax for Python 3.9 Compatibility (#1313)
Brings back #1287 since it was reverted (as full tests did not run).

### Changes

This PR resolves [Issue #1250](#1250), where the codebase relied on Python 3.10+ syntax for multiple context managers in a single `with` statement (using parentheses).

**Refactored Affected Modules**

- Updated `src/llmcompressor/modifiers/obcq/base.py`, `src/llmcompressor/modifiers/pruning/wanda/base.py`, `src/llmcompressor/modifiers/quantization/gptq/base.py`, `src/llmcompressor/pipelines/sequential/helpers.py`, and `src/llmcompressor/utils/helpers.py`
- Modified `tests/llmcompressor/modifiers/utils/test_hooks.py`

### Testing

- Verified that all modified modules run correctly under Python 3.9 and 3.10.
- Ran the updated test suite (`test_hooks.py`) to confirm hook disabling behavior remains intact.
- Confirmed no regressions in sparsification, quantization, or tracing pipelines.

### Related Issues

- Fixes [#1250](#1250)

---------

Signed-off-by: Rahul Tuli <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 79ff313 commit 4f35d48
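For reference, a minimal sketch of the syntax difference this commit addresses (the `acquire_a`/`acquire_b` context managers below are hypothetical, not from this codebase): parenthesizing multiple context managers in one `with` statement is only valid on Python 3.10+, while chaining them without parentheses also parses on Python 3.9.

```python
from contextlib import contextmanager

@contextmanager
def acquire_a():
    # hypothetical context manager, for illustration only
    yield "a"

@contextmanager
def acquire_b():
    # hypothetical context manager, for illustration only
    yield "b"

# Python 3.10+ only -- parenthesized form, SyntaxError on 3.9:
# with (
#     acquire_a() as a,
#     acquire_b() as b,
# ):
#     ...

# Python 3.9-compatible -- same managers chained in a single statement:
with acquire_a() as a, acquire_b() as b:
    print(a, b)
```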

File tree: 5 files changed (+13 −22 lines)


src/llmcompressor/modifiers/obcq/base.py

Lines changed: 3 additions & 5 deletions
@@ -119,11 +119,9 @@ def on_sequential_batch_end(self):
             num_samples = self._num_samples[module]
 
             logger.info(f"Sparsifying {name} using {num_samples} samples")
-            with (
-                torch.no_grad(),
-                align_module_device(module),
-                CompressionLogger(module) as comp_logger,
-            ):
+            with torch.no_grad(), align_module_device(module), CompressionLogger(
+                module
+            ) as comp_logger:
                 loss, sparsified_weight = sparsify_weight(
                     module=module,
                     hessians_dict=self._hessians,

src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 2 additions & 4 deletions
@@ -103,10 +103,8 @@ def on_sequential_batch_end(self):
             num_samples = self._num_samples[module]
 
             logger.info(f"Sparsifying {name} using {num_samples} samples")
-            with (
-                torch.no_grad(),
-                align_module_device(module),
-                CompressionLogger(module),
+            with torch.no_grad(), align_module_device(module), CompressionLogger(
+                module
             ):
                 sparsified_weight = sparsify_weight(
                     module=module,

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 5 additions & 6 deletions
@@ -337,12 +337,11 @@ def on_sequential_batch_end(self):
             quant_args = getattr_chain(module, "quantization_scheme.weights")
 
             logger.info(f"Quantizing {name} using {num_samples} samples")
-            with (
-                torch.no_grad(),
-                align_module_device(module),
-                self._maybe_onload_hessian(module),
-                CompressionLogger(module) as comp_logger,
-            ):
+            with torch.no_grad(), align_module_device(
+                module
+            ), self._maybe_onload_hessian(module), CompressionLogger(
+                module
+            ) as comp_logger:
                 loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                     module=module,
                     quant_args=quant_args,

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 1 addition & 4 deletions
@@ -71,10 +71,7 @@ def trace_subgraphs(
     concrete_args = populate_concrete_args(model, sample_input)
 
     # trace
-    with (
-        calibration_forward_context(model),
-        HooksMixin.disable_hooks(),
-    ):
+    with calibration_forward_context(model), HooksMixin.disable_hooks():
         graph = GraphModule(
             model,
             tracer.trace(

tests/llmcompressor/modifiers/utils/test_hooks.py

Lines changed: 2 additions & 3 deletions
@@ -139,9 +139,8 @@ def test_disable_hooks_composable():
     handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")
 
     # composing two keeps
-    with (
-        HooksMixin.disable_hooks(keep=set([handle_b])),
-        HooksMixin.disable_hooks(keep=set([handle_a])),
+    with HooksMixin.disable_hooks(keep=set([handle_b])), HooksMixin.disable_hooks(
+        keep=set([handle_a])
     ):
         model(model.dummy_inputs)
         assert mod_a.hook_called and mod_b.hook_called
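All five hunks apply the same pattern. Where chaining every manager onto one `with` line gets unwieldy, `contextlib.ExitStack` is another Python 3.9-compatible option (not the approach taken in this PR); a minimal sketch using a hypothetical `managed` helper:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def managed(name):
    # hypothetical context manager, for illustration only
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")

# ExitStack enters each manager in order and unwinds them in reverse on exit,
# avoiding both the 3.10+ parenthesized form and long chained `with` lines.
with ExitStack() as stack:
    a = stack.enter_context(managed("a"))
    b = stack.enter_context(managed("b"))
    print(a, b)
```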
