Skip to content

Commit cf09876

Browse files
committed
appropriate oneshot_device for determining onloading
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 63d1934 commit cf09876

File tree

9 files changed

+95
-65
lines changed

9 files changed

+95
-65
lines changed

src/llmcompressor/entrypoints/utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from transformers.utils.quantization_config import CompressedTensorsConfig
1616

1717
from llmcompressor.args import ModelArguments, TrainingArguments
18-
from llmcompressor.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype
18+
from llmcompressor.pytorch.model_load.helpers import parse_dtype
1919
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
2020
modify_save_pretrained,
2121
patch_tied_tensors_bug,
@@ -178,20 +178,12 @@ def initialize_model_from_path(
178178
else model_args.model_name_or_path
179179
)
180180

181-
# Fallback to CPU if GPU requested and not available
182-
model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
183-
184-
device_map = model_args.oneshot_device
185-
if training_args is not None and training_args.do_train:
186-
device_map = "auto"
187-
188181
model_kwargs = {
189182
"config": config,
190183
"cache_dir": model_args.cache_dir,
191184
"revision": model_args.model_revision,
192185
"use_auth_token": True if model_args.use_auth_token else None,
193186
"torch_dtype": parse_dtype(model_args.precision),
194-
"device_map": device_map,
195187
"trust_remote_code": model_args.trust_remote_code_model,
196188
}
197189

src/llmcompressor/modifiers/obcq/sgpt_mixin.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
170170
state.data.calib,
171171
self.sequential_targets,
172172
self.ignore,
173-
self,
173+
None, # TODO: pass in oneshot_device argument https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
174+
self, # TODO: use callbacks https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
174175
)
175176
return True
176177

@@ -186,7 +187,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
186187
state.model,
187188
state.data.calib,
188189
self.sequential_targets,
189-
self,
190+
None, # TODO: pass in oneshot_device argument https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
191+
self, # TODO: use callbacks https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
190192
)
191193
return True
192194

@@ -200,7 +202,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
200202
"Falling back to basic pipeline, which requires extra memory and "
201203
"may result in decreased accuracy"
202204
)
203-
run_basic(state.model, state.data.calib, self)
205+
run_basic(
206+
state.model,
207+
state.data.calib,
208+
None, # TODO: pass in oneshot_device argument https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
209+
self, # TODO: use callbacks https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
210+
)
204211
return True
205212

206213
def _infer_sequential_targets(

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
236236
state.data.calib,
237237
self.sequential_targets,
238238
self.ignore,
239-
self,
239+
None, # TODO: pass in oneshot_device argument https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
240+
self, # TODO: use callbacks https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
240241
)
241242
return True
242243

@@ -257,7 +258,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
257258
state.model,
258259
state.data.calib,
259260
self.sequential_targets,
260-
self,
261+
None, # TODO: pass in oneshot_device argument https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
262+
self, # TODO: use callbacks https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
261263
)
262264
return True
263265

@@ -272,7 +274,12 @@ def on_initialize(self, state: State, **kwargs) -> bool:
272274
"may result in decreased accuracy. Consider using "
273275
"`offload_hessians=True`"
274276
)
275-
run_basic(state.model, state.data.calib, self)
277+
run_basic(
278+
state.model,
279+
state.data.calib,
280+
None, # TODO: pass in oneshot_device argument https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
281+
self, # TODO: use callbacks https://github.com/vllm-project/llm-compressor/pull/1279 # noqa: E501
282+
)
276283
return True
277284

278285
def on_finalize(self, state: State, **kwargs) -> bool:

src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch
44
import torch.utils.data.dataloader
55
import tqdm
6-
from compressed_tensors.utils import get_execution_device
6+
from loguru import logger
7+
from transformers import PreTrainedModel
78

89
from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
910
from llmcompressor.pytorch.utils.helpers import tensors_to_device
@@ -16,8 +17,9 @@
1617

1718

1819
def run_pipeline(
19-
model: torch.nn.Module,
20+
model: PreTrainedModel,
2021
dataloader: torch.utils.data.DataLoader,
22+
oneshot_device: Optional[torch.device],
2123
callback_modifier: Optional["Modifier"] = None,
2224
):
2325
"""
@@ -32,12 +34,16 @@ def run_pipeline(
3234
:param dataloader: loads data for calibration
3335
:param callback_modifier: Temporary HACK which should be replaced by event callback
3436
"""
35-
model_device = get_execution_device(model)
37+
if oneshot_device is not None:
38+
logger.warning(
39+
"Basic pipeline does not utilize `oneshot_device` argument, instead use "
40+
"`from_pretrained(device_map=...)` to determine onloading behavior"
41+
)
3642

3743
with calibration_forward_context(model):
3844
for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
3945
batch = apply_pad_mask_to_batch(batch)
40-
batch = tensors_to_device(batch, model_device)
46+
batch = tensors_to_device(batch, model.device)
4147
model(**batch)
4248

4349
# TODO: replace with a lifecycle event

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
maybe_inject_pos_embeddings,
1313
to_next_layer_kwargs,
1414
)
15+
from llmcompressor.pipelines.sequential.helpers import infer_oneshot_device
1516
from llmcompressor.utils.helpers import align_modules, calibration_forward_context
1617

1718
if TYPE_CHECKING:
@@ -24,6 +25,7 @@ def run_pipeline(
2425
model: torch.nn.Module,
2526
dataloader: torch.utils.data.DataLoader,
2627
sequential_targets: List[str],
28+
oneshot_device: Optional[torch.device],
2729
callback_modifier: Optional["Modifier"] = None,
2830
):
2931
"""
@@ -46,8 +48,13 @@ def run_pipeline(
4648
:param model: model being calibrated
4749
:param dataloader: loads data for calibration
4850
:param sequential_targets: patterns which match to the layer modules of the model
51+
:param oneshot_device: device to onload layers onto; uses device_map if None
4952
:param callback_modifier: Temporary HACK which should be replaced by event callback
5053
"""
54+
# if the model is dispatched, use the dispatch to determine onloading, return None
55+
# otherwise, infer a oneshot device (either user passed or the first available gpu)
56+
oneshot_device = infer_oneshot_device(model, oneshot_device)
57+
5158
# find layers
5259
layers = match_modules(model, sequential_targets)
5360

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import inspect
22
from collections import deque
33
from dataclasses import dataclass
4-
from typing import Any, Callable, Dict, List, Set, Union
4+
from typing import Any, Callable, Dict, List, Optional, Set, Union
55

6+
import torch
67
from compressed_tensors import has_offloaded_params
78
from compressed_tensors.quantization import find_name_or_class_matches
9+
from loguru import logger
810
from torch.fx import Graph, GraphModule, Node
911
from torch.fx.proxy import Argument
1012
from torch.nn import Module
@@ -399,3 +401,33 @@ def get_subgraph_modules(subgraph: Graph, parent_graph: GraphModule) -> List[Mod
399401
modules_ops: List[Node] = subgraph.find_nodes(op="call_module")
400402
called_modules = [parent_graph.get_submodule(op.target) for op in modules_ops]
401403
return list({m for module in called_modules for m in module.modules()})
404+
405+
406+
def infer_oneshot_device(
407+
model: PreTrainedModel, oneshot_device: Optional[torch.device]
408+
) -> Optional[torch.device]:
409+
if is_gpu_dispatched(model):
410+
logger.warning(
411+
"Calibrating a model dispatched to the gpu can potentially lead to OOM "
412+
"errors. Consider loading the model without a `device_map` and instead "
413+
"executing with `cuda:0` (set `oneshot_device` to override this default)"
414+
)
415+
return None
416+
417+
elif oneshot_device is None:
418+
has_cuda = torch.cuda.is_available()
419+
oneshot_device = torch.device("cuda:0") if has_cuda else torch.device("cpu")
420+
logger.info(f"No oneshot_device passed, using {oneshot_device}")
421+
422+
return oneshot_device
423+
424+
425+
def is_gpu_dispatched(model: PreTrainedModel) -> bool:
426+
for module in model.modules():
427+
if any(param.device not in ("meta", "cpu") for param in module.parameters()):
428+
return True
429+
430+
if has_offloaded_params(module) and module._hf_hook.execution_device != "cpu":
431+
return True
432+
433+
return False

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22

33
import torch
44
import torch.utils.data.dataloader
5-
from compressed_tensors.utils import get_execution_device
65
from tqdm import tqdm
6+
from transformers import PreTrainedModel
77

88
from llmcompressor.modifiers.utils.hooks import HooksMixin
99
from llmcompressor.pipelines.cache import IntermediatesCache
10-
from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
10+
from llmcompressor.pipelines.sequential.helpers import (
11+
infer_oneshot_device,
12+
trace_subgraphs,
13+
)
1114
from llmcompressor.utils.helpers import align_modules, calibration_forward_context
1215

1316
if TYPE_CHECKING:
@@ -17,10 +20,11 @@
1720

1821

1922
def run_pipeline(
20-
model: torch.nn.Module,
23+
model: PreTrainedModel,
2124
dataloader: torch.utils.data.DataLoader,
2225
sequential_targets: List[str],
2326
ignore: List[str],
27+
oneshot_device: Optional[torch.device],
2428
callback_modifier: Optional["Modifier"] = None,
2529
):
2630
"""
@@ -45,16 +49,22 @@ def run_pipeline(
4549
:param dataloader: loads data for calibration
4650
:param sequential_targets: patterns which match to the layer modules of the model
4751
:param ignore: patterns which match to modules which should be ignored by tracing
52+
:param oneshot_device: device to onload layers onto; uses device_map if None
53+
:param callback_modifier: Temporary HACK which should be replaced by event callback
4854
"""
55+
# if the model is dispatched, use the dispatch to determine onloading, return None
56+
# otherwise, infer a oneshot device (either user passed or the first available gpu)
57+
oneshot_device = infer_oneshot_device(model, oneshot_device)
58+
4959
# trace subgraphs
5060
sample_input = next(iter(dataloader))
5161
subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
5262

53-
with calibration_forward_context(model):
54-
# prepare intermediates cache
55-
model_device = get_execution_device(model)
56-
intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
63+
# prepare intermediates cache
64+
model_device = oneshot_device or model.device
65+
intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
5766

67+
with calibration_forward_context(model):
5868
num_subgraphs = len(subgraphs)
5969
for subgraph_index, subgraph in enumerate(subgraphs):
6070
# prepare tqdm description texts
@@ -64,7 +74,7 @@ def run_pipeline(
6474
# compile subgraph forward function
6575
forward_function = subgraph.compile_forward()
6676

67-
with align_modules(subgraph.modules):
77+
with align_modules(subgraph.modules, oneshot_device):
6878
# do a preliminary pass to trigger modifier hooks
6979
for batch_index in tqdm(range(len(dataloader)), desc=calib_desc):
7080
inputs = intermediates.fetch(batch_index, subgraph.input_names)

src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
__all__ = [
1717
"copy_python_files_from_model_cache",
18-
"fallback_to_cpu",
1918
"parse_dtype",
2019
"get_session_model",
2120
"get_completed_stages",
@@ -41,22 +40,6 @@ def save_checkpoint(
4140
processor.save_pretrained(save_path)
4241

4342

44-
def fallback_to_cpu(device: str) -> str:
45-
"""
46-
Takes in a device string and forces it to cpu if cuda is not available
47-
48-
:param device: device id to check
49-
:return: device modified for CUDA status
50-
"""
51-
if "cuda" in device and not torch.cuda.is_available():
52-
logger.warning(
53-
f"Requested {device} but CUDA is not available, falling back to CPU"
54-
)
55-
return "cpu"
56-
57-
return device
58-
59-
6043
def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype:
6144
"""
6245
:param dtype_arg: dtype or string to parse

src/llmcompressor/utils/helpers.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import numpy
2525
import torch
2626
from compressed_tensors.quantization import disable_quantization, enable_quantization
27-
from compressed_tensors.utils import has_offloaded_params
27+
from compressed_tensors.utils import align_module_device
2828
from loguru import logger
2929
from transformers import PreTrainedModel
3030

@@ -1068,22 +1068,8 @@ def preserve_attr(base: object, attr: str):
10681068
def align_modules(
10691069
modules: Iterable[torch.nn.Module], execution_device: Optional[torch.device] = None
10701070
):
1071-
original_devices = {}
1072-
can_offload = [module for module in modules if has_offloaded_params(module)]
1073-
1074-
for module in can_offload:
1075-
if execution_device is not None:
1076-
module._hf_hook.execution_device = execution_device
1077-
original_devices[module] = module._hf_hook.execution_device
1078-
1079-
module._hf_hook.pre_forward(module)
1080-
module._hf_hook.offload = False
1081-
1082-
yield
1083-
1084-
for module in can_offload:
1085-
if execution_device is not None:
1086-
module._hf_hook.execution_device = original_devices[module]
1087-
1088-
module._hf_hook.offload = True
1089-
module._hf_hook.post_forward(module, None)
1071+
with contextlib.ExitStack() as stack:
1072+
[
1073+
stack.enter_context(align_module_device(module, execution_device))
1074+
for module in modules
1075+
]

0 commit comments

Comments
 (0)