
Commit ad71c5b

cleanup and nits
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8351ac9 commit ad71c5b

5 files changed, +10 -39 lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 0 additions & 4 deletions
@@ -128,10 +128,6 @@ def __init__(
 
         # offload to cpu if possible
         if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
-            # TODO: consider renaming function similar to "offload_dispatch_model"
-            # TODO: modify function to remove any hooks if they already exist (making
-            # sure to move to cpu when removing hook
-            # TODO: remove hook in util
            remove_hook_from_module(model_args.model, recurse=True)
            force_cpu_offload(model_args.model, model_args.oneshot_device)
        else:
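
For context on the branch above, here is a minimal sketch of the dispatch pattern once the TODOs are dropped. It is an illustration under stated assumptions: accelerate's cpu_offload stands in for llmcompressor's force_cpu_offload (whose body is not shown in this diff), and offload_if_possible is a hypothetical wrapper name.

import torch
from accelerate import cpu_offload
from accelerate.hooks import remove_hook_from_module

def offload_if_possible(model: torch.nn.Module, oneshot_device: str) -> None:
    # hypothetical helper mirroring the branch in oneshot.py above
    if "cuda" in str(oneshot_device) and torch.cuda.is_available():
        # clear any hooks left over from a previous dispatch before re-dispatching
        remove_hook_from_module(model, recurse=True)
        # keep weights on cpu and onload them to the execution device per forward;
        # stand-in for force_cpu_offload, which is assumed to behave similarly
        cpu_offload(model, execution_device=torch.device(oneshot_device))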

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@
 )
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
+    disable_offloading,
     get_targets_from_modifiers,
-    keep_onload_context,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

@@ -88,7 +88,7 @@ def __call__(
         prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"
 
         # reduce memory movement by keeping modules onloaded
-        with keep_onload_context():
+        with disable_offloading():
             # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = intermediates.fetch(batch_idx)
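
As a rough illustration of why both loops sit under a single disable_offloading() block: with accelerate offloading active, every forward call onloads and then offloads a layer's weights, so wrapping the preliminary and propagation passes together lets the second pass reuse weights the first pass already onloaded. The snippet below is a hypothetical, simplified usage; layer and dataloader are placeholder names, not objects from this diff.

from llmcompressor.pipelines.sequential.helpers import disable_offloading

# layer and dataloader are placeholders for illustration only
with disable_offloading():
    for batch in dataloader:   # preliminary pass: calibration hooks observe activations
        layer(**batch)
    for batch in dataloader:   # propagation pass: weights are still onloaded
        output = layer(**batch)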

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 6 additions & 2 deletions
@@ -28,7 +28,7 @@
     "trace_subgraphs",
     "Subgraph",
     "get_targets_from_modifiers",
-    "keep_onload_context",
+    "disable_offloading",
 ]
 
 
@@ -494,7 +494,11 @@ def is_ancestor(module: Module) -> bool:
 
 
 @contextlib.contextmanager
-def keep_onload_context():
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
     original_pre_forward = AlignDevicesHook.pre_forward
     onloaded_modules = dict()

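The two visible body lines (original_pre_forward and onloaded_modules) suggest the context manager works by monkey-patching AlignDevicesHook.pre_forward. The sketch below is an illustrative reconstruction consistent with those two lines, not the exact body from the repository: it records each hook the first time its module onloads, disables that hook's offloading for the duration of the context, then restores the original behavior and offloads everything on exit.

import contextlib
from typing import Dict, Tuple

from accelerate.hooks import AlignDevicesHook
from torch.nn import Module

@contextlib.contextmanager
def disable_offloading():
    """Illustrative reconstruction; only the first two body lines are confirmed by the diff."""
    original_pre_forward = AlignDevicesHook.pre_forward
    onloaded_modules: Dict[Module, Tuple[AlignDevicesHook, bool]] = dict()

    def keep_onload_pre_forward(self: AlignDevicesHook, module: Module, *args, **kwargs):
        # onload as usual, then remember the hook and stop it from offloading again
        ret = original_pre_forward(self, module, *args, **kwargs)
        if module not in onloaded_modules:
            onloaded_modules[module] = (self, self.offload)
            self.offload = False
        return ret

    try:
        AlignDevicesHook.pre_forward = keep_onload_pre_forward
        yield
    finally:
        # restore the original hook behavior and offload the modules kept onloaded
        AlignDevicesHook.pre_forward = original_pre_forward
        for module, (hook, offload) in onloaded_modules.items():
            hook.offload = offload
            hook.post_forward(module, None)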
src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
+    disable_offloading,
     get_targets_from_modifiers,
-    keep_onload_context,
     trace_subgraphs,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

@@ -86,7 +86,7 @@ def __call__(
         prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
 
         # reduce memory movement by keeping modules onloaded
-        with keep_onload_context():
+        with disable_offloading():
             # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = activations.fetch(batch_idx, subgraph.input_names)

src/llmcompressor/utils/module.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments
