
Commit 2a1aa1b

sq: remove empty_cache, add calib context

Signed-off-by: Kyle Sayers <[email protected]>

Parent: fe67d7e

2 files changed: 8 additions, 13 deletions

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 8 additions & 9 deletions
@@ -14,6 +14,7 @@
 )
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
+from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_matching_layer,
@@ -250,12 +251,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the SmoothQuant modifier"
             )
 
-        run_calibration_forward(
-            model,
-            calibration_dataloader,
-            self.num_calibration_steps,
-            self.calibration_function,
-        )
+        with calibration_forward_context(model):
+            run_calibration_forward(
+                model,
+                calibration_dataloader,
+                self.num_calibration_steps,
+                self.calibration_function,
+            )
 
         # remove the hooks now that we are done calibrating
         self.remove_hooks()
@@ -313,9 +315,6 @@ def smooth(module):
             smooth(layer)
             smooth(smooth_layer)
 
-        # clear out allocated smoothing scales
-        torch.cuda.empty_cache()
-
     def _calculate_smoothing_scales(
         self, balance_layers: List[Module], activation_scales: torch.Tensor
     ) -> List[float]:
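The diff imports calibration_forward_context from llmcompressor.utils.helpers but does not show its implementation. A minimal sketch of what such a context manager plausibly does, assuming it disables autograd tracking and puts the model in eval mode for the duration of calibration (everything beyond the name and module path is an assumption, not the library's actual code):

    import contextlib

    import torch
    from torch.nn import Module


    @contextlib.contextmanager
    def calibration_forward_context(model: Module):
        # Hypothetical sketch: run calibration forward passes without autograd
        # bookkeeping and with the model in eval mode, restoring prior state
        # afterwards. The real helper in llmcompressor may do more than this.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                yield
        finally:
            if was_training:
                model.train()

Wrapping run_calibration_forward in such a context, as the hunk above does, means the calibration passes never build an autograd graph, which is why the explicit cache-clearing below the call sites can be dropped.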

src/llmcompressor/modifiers/utils/pytorch_helpers.py

Lines changed: 0 additions & 4 deletions
@@ -106,10 +106,6 @@ def run_calibration_forward(
             # move on to next calibration sample
             intermediates.append((e.args, e.kwargs))
 
-    # TODO: not ideal, figure out where we aren't freeing memory instead
-    # currently without this we run OOM on the 2nd forward pass
-    torch.cuda.empty_cache()
-
     return intermediates
 
 
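Both deleted torch.cuda.empty_cache() calls were workarounds for memory held across forward passes; running calibration under a no-grad context avoids retaining the autograd graph in the first place, so the caches no longer fill with unreachable intermediates. A standalone snippet (not part of the commit) illustrating the difference the context makes:

    import torch

    model = torch.nn.Linear(4096, 4096)
    x = torch.randn(8, 4096)

    # Without no_grad, the forward pass records an autograd graph that keeps
    # intermediate tensors alive; this is the memory pressure the deleted
    # torch.cuda.empty_cache() calls were papering over.
    y_train = model(x)

    # Under no_grad (what the calibration context is assumed to provide),
    # no graph is recorded, so intermediates are freed as soon as possible.
    with torch.no_grad():
        y_calib = model(x)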