Commit 7a8f569 (parent 2a1aa1b)

qm: use calib context

Signed-off-by: Kyle Sayers <[email protected]>

1 file changed: 8 additions, 12 deletions

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 8 additions & 12 deletions
@@ -32,6 +32,7 @@
     run_calibration_forward,
 )
 from llmcompressor.observers.helpers import get_observer_token_count
+from llmcompressor.utils.helpers import calibration_forward_context

 __all__ = ["QuantizationModifier"]

@@ -309,18 +310,13 @@ def _calibrate(self, module: Module):
             f"{len(self.calibration_dataloader_)} samples..."
         )

-        module_training = module.training
-        module.eval()
-
-        run_calibration_forward(
-            module,
-            self.calibration_dataloader_,
-            self.num_calibration_steps,
-            self.calibration_function_,
-        )
-
-        if module_training:
-            module.train()
+        with calibration_forward_context(module):
+            run_calibration_forward(
+                module,
+                self.calibration_dataloader_,
+                self.num_calibration_steps,
+                self.calibration_function_,
+            )

     def _check_token_distribution(
         self, model: Module, threshold: Optional[float] = None
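The change replaces manual train/eval bookkeeping with a context manager, which guarantees the module's original training mode is restored even if calibration raises. As a rough sketch of the pattern (a hypothetical reimplementation, not the actual `calibration_forward_context` from `llmcompressor.utils.helpers`, which may also disable gradient tracking and patch other state), using a minimal stand-in class instead of `torch.nn.Module`:

```python
from contextlib import contextmanager

class TinyModule:
    """Minimal stand-in mimicking torch.nn.Module's train/eval toggling."""
    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False
        return self

    def train(self, mode=True):
        self.training = mode
        return self

@contextmanager
def calibration_forward_context(module):
    """Put the module in eval mode for calibration, restoring it on exit.

    Hypothetical sketch: the real llmcompressor helper wraps more state
    (e.g. gradient tracking); this shows only the train/eval handling
    that the diff above removes from _calibrate.
    """
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        # Restore the original mode even if the calibration forward raises.
        module.train(was_training)

m = TinyModule()
with calibration_forward_context(m):
    assert m.training is False   # eval mode during calibration
assert m.training is True        # training flag restored afterwards
```

Compared with the removed `module_training = module.training` / `if module_training: module.train()` pair, the `try/finally` inside the context manager also covers the error path, so a failed calibration run cannot leave the model stuck in eval mode.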
