File tree Expand file tree Collapse file tree 1 file changed +8
-12
lines changed
src/llmcompressor/modifiers/quantization/quantization Expand file tree Collapse file tree 1 file changed +8
-12
lines changed Original file line number Diff line number Diff line change 32
32
run_calibration_forward ,
33
33
)
34
34
from llmcompressor .observers .helpers import get_observer_token_count
35
+ from llmcompressor .utils .helpers import calibration_forward_context
35
36
36
37
__all__ = ["QuantizationModifier" ]
37
38
@@ -309,18 +310,13 @@ def _calibrate(self, module: Module):
309
310
f"{ len (self .calibration_dataloader_ )} samples..."
310
311
)
311
312
312
- module_training = module .training
313
- module .eval ()
314
-
315
- run_calibration_forward (
316
- module ,
317
- self .calibration_dataloader_ ,
318
- self .num_calibration_steps ,
319
- self .calibration_function_ ,
320
- )
321
-
322
- if module_training :
323
- module .train ()
313
+ with calibration_forward_context (module ):
314
+ run_calibration_forward (
315
+ module ,
316
+ self .calibration_dataloader_ ,
317
+ self .num_calibration_steps ,
318
+ self .calibration_function_ ,
319
+ )
324
320
325
321
def _check_token_distribution (
326
322
self , model : Module , threshold : Optional [float ] = None
You can’t perform that action at this time.
0 commit comments