Skip to content

Commit 382c3e6

Browse files
committed
dispatching
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 61da757 commit 382c3e6

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

src/llmcompressor/entrypoints/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from pathlib import PosixPath
44
from typing import Optional, Tuple
55

6+
import torch
7+
from accelerate import dispatch_model
8+
from accelerate.hooks import remove_hook_from_module
69
from loguru import logger
710
from torch.nn import Module
811
from transformers import (
@@ -27,6 +30,7 @@
2730
)
2831
from llmcompressor.typing import Processor
2932
from llmcompressor.utils.fsdp.helpers import is_fsdp_model
33+
from llmcompressor.utils.offload import has_device_execution
3034

3135

3236
def pre_process(model_args: "ModelArguments"):
@@ -65,6 +69,15 @@ def pre_process(model_args: "ModelArguments"):
6569
# wrap model.save_pretrained
6670
modify_save_pretrained(model_args.model)
6771

72+
# dispatch to oneshot device if loaded onto CPU
73+
# this needs to be done before qparams are initialized
74+
if not has_device_execution(model) and torch.cuda.is_available():
75+
model_args.oneshot_device = model_args.oneshot_device or torch.device("cuda:0")
76+
remove_hook_from_module(model_args.model, recurse=True)
77+
model_args.model = dispatch_model(
78+
model_args.model, main_device=model_args.oneshot_device, force_hooks=True
79+
)
80+
6881

6982
def post_process(
7083
model_args: Optional["ModelArguments"] = None,

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from llmcompressor.modifiers.utils.hooks import HooksMixin
1818
from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr
19+
from llmcompressor.utils.offload import has_device_parameters
1920

2021
__all__ = ["trace_subgraphs", "Subgraph"]
2122

@@ -403,11 +404,12 @@ def get_subgraph_modules(subgraph: Graph, parent_graph: GraphModule) -> List[Mod
403404
def infer_oneshot_device(
404405
model: PreTrainedModel, oneshot_device: Optional[torch.device]
405406
) -> Optional[torch.device]:
406-
if is_gpu_dispatched(model):
407+
if has_device_parameters(model):
407408
logger.warning(
408-
"Calibrating a model dispatched to the gpu can potentially lead to OOM "
409-
"errors. Consider loading the model without a `device_map` and instead "
410-
"executing with `cuda:0` (set `oneshot_device` to override this default)"
409+
"Calibrating a model with gpu parameters using the sequential pipeline can "
410+
"potentially lead to OOM errors. Consider loading the model without a "
411+
"`device_map` and instead executing with `cuda:0` (set `oneshot_device` "
412+
"to override this default)"
411413
)
412414
return None
413415

@@ -417,17 +419,3 @@ def infer_oneshot_device(
417419
logger.info(f"No oneshot_device passed, using {oneshot_device}")
418420

419421
return oneshot_device
420-
421-
422-
def is_gpu_dispatched(model: PreTrainedModel) -> bool:
423-
for module in model.modules():
424-
if any(
425-
param.device not in (torch.device("meta"), torch.device("cpu"))
426-
for param in module.parameters()
427-
):
428-
return True
429-
430-
if has_offloaded_params(module) and module._hf_hook.execution_device != "cpu":
431-
return True
432-
433-
return False

src/llmcompressor/utils/offload.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
from accelerate.utils import has_offloaded_params
3+
4+
__all__ = ["has_device_parameters", "has_device_execution_hook", "has_device_execution"]
5+
6+
7+
def has_device_parameters(model: torch.nn.Module) -> bool:
    """
    Return ``True`` if any parameter of ``model`` lives on a real device,
    i.e. on anything other than the CPU or the meta device.

    :param model: module whose parameters are inspected
    :return: ``True`` when at least one parameter is on a non-CPU,
        non-meta device; ``False`` otherwise (including parameterless models)
    """
    # devices that do NOT count as "device" placement
    host_devices = (torch.device("cpu"), torch.device("meta"))
    for parameter in model.parameters():
        if parameter.device not in host_devices:
            return True
    return False
12+
13+
14+
def has_device_execution_hook(model: torch.nn.Module) -> bool:
    """
    Return ``True`` if any submodule of ``model`` carries an accelerate
    offload hook whose execution device is not the CPU.

    :param model: module whose submodules are inspected
    :return: ``True`` when some offloaded submodule executes on a
        non-CPU device; ``False`` otherwise
    """
    for submodule in model.modules():
        if not has_offloaded_params(submodule):
            continue
        # NOTE(review): compares against the string "cpu"; if accelerate ever
        # stores the execution device as torch.device("cpu") or an int index,
        # this comparison would not match — confirm against accelerate's hook API
        if submodule._hf_hook.execution_device != "cpu":
            return True
    return False
19+
20+
21+
def has_device_execution(model: torch.nn.Module) -> bool:
    """
    Return ``True`` if ``model`` will execute on a device, either because an
    accelerate hook dispatches it there or because its parameters already
    reside on one.

    :param model: module to inspect
    :return: ``True`` when device execution is detected via hooks or
        parameter placement; ``False`` otherwise
    """
    # check hooks first, mirroring the cheaper-signal-first disjunction
    if has_device_execution_hook(model):
        return True
    return has_device_parameters(model)

0 commit comments

Comments
 (0)