Commit 8e58e35 (parent: 91b349b)

remove fallback_to_cpu, use ct utils

Signed-off-by: Kyle Sayers <[email protected]>

5 files changed: +6 -70 lines

src/llmcompressor/entrypoints/utils.py

Lines changed: 2 additions & 13 deletions

@@ -17,7 +17,7 @@
 from llmcompressor.args import ModelArguments, RecipeArguments, TrainingArguments
 from llmcompressor.core import reset_session
-from llmcompressor.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype
+from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
     patch_tied_tensors_bug,
@@ -197,20 +197,12 @@ def initialize_model_from_path(
         else model_args.model_name_or_path
     )

-    # Fallback to CPU if GPU requested and not available
-    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
-
-    device_map = model_args.oneshot_device
-    if training_args is not None and training_args.do_train:
-        device_map = "auto"
-
     model_kwargs = {
         "config": config,
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "use_auth_token": True if model_args.use_auth_token else None,
         "torch_dtype": parse_dtype(model_args.precision),
-        "device_map": device_map,
         "trust_remote_code": model_args.trust_remote_code_model,
     }

@@ -220,10 +212,7 @@ def initialize_model_from_path(
             run_compressed=False
         )

-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        **model_kwargs,
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
     if "sequence_length" in model_kwargs:
         model.seqlen = model_kwargs["sequence_length"]
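With the device logic gone, the call site reduces to a plain `from_pretrained`. A minimal sketch of the resulting load path, with placeholder values standing in for the fields of `ModelArguments` (the model id and dtype below are illustrative, not from the commit):

```python
import torch
from transformers import AutoModelForCausalLM

model_path = "facebook/opt-125m"  # placeholder model id

model_kwargs = {
    "torch_dtype": torch.bfloat16,  # parse_dtype(model_args.precision) in the real code
    "trust_remote_code": False,
    # no "device_map" entry anymore: placement is left to transformers'
    # defaults, with onloading/offloading handled by compressed-tensors utils
}

model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
```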

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 2 additions & 4 deletions

@@ -2,6 +2,7 @@

 import torch
 import tqdm
+from compressed_tensors.utils import disable_offloading
 from loguru import logger
 from torch.utils.data.dataloader import DataLoader

@@ -15,10 +16,7 @@
     to_next_layer_kwargs,
 )
 from llmcompressor.pipelines.registry import CalibrationPipeline
-from llmcompressor.pipelines.sequential.helpers import (
-    disable_offloading,
-    get_targets_from_modifiers,
-)
+from llmcompressor.pipelines.sequential.helpers import get_targets_from_modifiers
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

 if TYPE_CHECKING:
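Both pipelines now import `disable_offloading` from `compressed_tensors.utils` rather than from the local `sequential.helpers` module; per the removed docstring, it keeps accelerate-hooked modules onloaded until the context exits. A hedged usage sketch (the layer and batches below are stand-ins, not the pipeline's real objects):

```python
import torch
from compressed_tensors.utils import disable_offloading

layer = torch.nn.Linear(16, 16)                   # stand-in for a dispatched model layer
batches = [torch.randn(4, 16) for _ in range(8)]  # stand-in calibration inputs

# Inside the context, an accelerate AlignDevicesHook onloads a module's
# weights once and skips the per-call offload; everything that was onloaded
# is offloaded a single time when the context exits.
with disable_offloading():
    for batch in batches:
        layer(batch)
```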

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 1 addition & 34 deletions

@@ -5,7 +5,6 @@
 from typing import Any, Dict, List, Optional, Set

 import torch
-from accelerate.hooks import AlignDevicesHook
 from compressed_tensors import has_offloaded_params
 from compressed_tensors.quantization import find_name_or_class_matches
 from loguru import logger
@@ -24,12 +23,7 @@

 from .ast_helpers import autowrap_forwards

-__all__ = [
-    "trace_subgraphs",
-    "Subgraph",
-    "get_targets_from_modifiers",
-    "disable_offloading",
-]
+__all__ = ["trace_subgraphs", "Subgraph", "get_targets_from_modifiers"]


 @dataclass
@@ -491,30 +485,3 @@ def is_ancestor(module: Module) -> bool:

     is_ancestor(model)
     return ancestors
-
-
-@contextlib.contextmanager
-def disable_offloading():
-    """
-    Keep modules onloaded and disable offloading until this context exits.
-    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
-    """
-    original_pre_forward = AlignDevicesHook.pre_forward
-    onloaded_modules = dict()
-
-    # onload once and disable any future onloading/offloading steps
-    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
-        ret = original_pre_forward(self, module, *args, **kwargs)
-        if module not in onloaded_modules:
-            onloaded_modules[module] = (self, self.offload)
-            self.offload = False
-        return ret
-
-    # use the patched pre_forward function within the context
-    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
-        yield
-
-    # manually offload all modules that were onloaded
-    for module, (hook, offload) in onloaded_modules.items():
-        hook.offload = offload
-        hook.post_forward(module, None)
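The deleted context manager worked by temporarily monkeypatching `AlignDevicesHook.pre_forward` via `patch_attr`. For readers unfamiliar with that pattern, a self-contained sketch of a `patch_attr`-style context manager (written out here for illustration; the project's own helper may differ in detail):

```python
import contextlib

@contextlib.contextmanager
def patch_attr(obj, name, value):
    """Temporarily replace obj.<name> with value, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)

class Hook:
    def pre_forward(self):
        return "original"

# the patched method is visible only inside the context
with patch_attr(Hook, "pre_forward", lambda self: "patched"):
    assert Hook().pre_forward() == "patched"
assert Hook().pre_forward() == "original"
```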

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 1 addition & 2 deletions

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING

 import torch
-from compressed_tensors.utils import get_execution_device
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from loguru import logger
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
@@ -11,7 +11,6 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
-    disable_offloading,
     get_targets_from_modifiers,
     trace_subgraphs,
 )

src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 0 additions & 17 deletions

@@ -15,7 +15,6 @@

 __all__ = [
     "copy_python_files_from_model_cache",
-    "fallback_to_cpu",
     "parse_dtype",
     "get_session_model",
     "get_completed_stages",
@@ -71,22 +70,6 @@ def save_checkpoint(
     compressor.decompress_model(model)


-def fallback_to_cpu(device: str) -> str:
-    """
-    Takes in a device string and forces it to cpu if cuda is not available
-
-    :param device: device id to check
-    :return: device modified for CUDA status
-    """
-    if "cuda" in device and not torch.cuda.is_available():
-        logger.warning(
-            f"Requested {device} but CUDA is not available, falling back to CPU"
-        )
-        return "cpu"
-
-    return device
-
-
 def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype:
     """
     :param dtype_arg: dtype or string to parse
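`fallback_to_cpu` is removed without a replacement in this module; the entrypoint simply stops choosing a device. Downstream code that still wants the old behavior can inline the check, roughly as follows (a sketch reproducing the removed logic; the function name is mine):

```python
import torch

def pick_device(requested: str = "cuda:0") -> str:
    # mirror of the removed helper: fall back to CPU when CUDA is
    # requested but not available on this machine
    if "cuda" in requested and not torch.cuda.is_available():
        return "cpu"
    return requested

device = pick_device("cuda:0")
```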
