5 files changed: +10 −39 lines

File 1 of 5:

@@ -128,10 +128,6 @@ def __init__(
         # offload to cpu if possible
         if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
-            # TODO: consider renaming function similar to "offload_dispatch_model"
-            # TODO: modify function to remove any hooks if they already exist (making
-            # sure to move to cpu when removing hook
-            # TODO: remove hook in util
             remove_hook_from_module(model_args.model, recurse=True)
             force_cpu_offload(model_args.model, model_args.oneshot_device)
         else:
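For context, the surviving two calls first strip any accelerate hooks already attached to the model, then re-dispatch it with CPU offloading onto the oneshot device. A minimal sketch of the same pattern using only accelerate's public API follows; note that force_cpu_offload is llmcompressor's own helper, accelerate's cpu_offload is used here purely as an illustrative stand-in, and offload_model_for_oneshot is a hypothetical name.

import torch
from accelerate import cpu_offload
from accelerate.hooks import remove_hook_from_module

def offload_model_for_oneshot(model: torch.nn.Module, oneshot_device: torch.device):
    # Hypothetical helper, not part of llmcompressor.
    # Drop any hooks left over from a previous dispatch so they do not stack.
    remove_hook_from_module(model, recurse=True)
    # Keep weights on CPU and stream them to the execution device per forward call.
    cpu_offload(model, execution_device=oneshot_device)
    return model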
File 2 of 5:

@@ -16,7 +16,7 @@
 )
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
+    disable_offloading,
     get_targets_from_modifiers,
-    keep_onload_context,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

@@ -88,7 +88,7 @@ def __call__(
         prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"

         # reduce memory movement by keeping modules onloaded
-        with keep_onload_context():
+        with disable_offloading():
             # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = intermediates.fetch(batch_idx)
File 3 of 5:

@@ -28,5 +28,5 @@
     "trace_subgraphs",
     "Subgraph",
     "get_targets_from_modifiers",
-    "keep_onload_context",
+    "disable_offloading",
 ]

@@ -494,7 +494,11 @@ def is_ancestor(module: Module) -> bool:

 @contextlib.contextmanager
-def keep_onload_context():
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
     original_pre_forward = AlignDevicesHook.pre_forward
     onloaded_modules = dict()
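For reference, a minimal sketch of how such a context manager can be built on top of accelerate's AlignDevicesHook. This is an illustration of the idea only, not the library's implementation; the name disable_offloading_sketch and the bookkeeping details are assumptions.

import contextlib
from accelerate.hooks import AlignDevicesHook

@contextlib.contextmanager
def disable_offloading_sketch():
    # Patch AlignDevicesHook.pre_forward so that, the first time a module is
    # onloaded inside this context, its hook stops offloading the weights
    # after forward. On exit, restore the flags and offload everything back.
    original_pre_forward = AlignDevicesHook.pre_forward
    onloaded = {}  # module -> (hook, original "offload" flag)

    def keep_onloaded(self, module, *args, **kwargs):
        ret = original_pre_forward(self, module, *args, **kwargs)
        if module not in onloaded:
            onloaded[module] = (self, self.offload)
            self.offload = False  # post_forward no longer moves weights off-device
        return ret

    try:
        AlignDevicesHook.pre_forward = keep_onloaded
        yield
    finally:
        AlignDevicesHook.pre_forward = original_pre_forward
        for module, (hook, offload) in onloaded.items():
            hook.offload = offload
            hook.post_forward(module, None)  # move weights back off-device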
File 4 of 5:

@@ -11,8 +11,8 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
+    disable_offloading,
     get_targets_from_modifiers,
-    keep_onload_context,
     trace_subgraphs,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

@@ -86,7 +86,7 @@ def __call__(
         prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

         # reduce memory movement by keeping modules onloaded
-        with keep_onload_context():
+        with disable_offloading():
             # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = activations.fetch(batch_idx, subgraph.input_names)
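Taken together, the two pipeline changes above amount to the same calling pattern: wrap the repeated calibration forward passes in the renamed context so offloaded modules stay on the execution device for the whole loop instead of being moved off and back per batch. A minimal usage sketch, with model and dataloader as placeholders:

from llmcompressor.pipelines.sequential.helpers import disable_offloading

def run_calibration(model, dataloader):  # placeholder model and dataloader
    with disable_offloading():
        for batch in dataloader:
            model(**batch)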
File 5 of 5: this file was deleted.