Skip to content

Commit ecfe15d

Browse files
committed
move offloading outside of preprocess, which is shared with train
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 97c8d30 commit ecfe15d

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from datetime import datetime
33
from typing import Optional
44

5+
import torch
6+
from compressed_tensors.utils import force_cpu_offload
57
from loguru import logger
68
from torch.utils.data import DataLoader
79
from transformers import PreTrainedModel
@@ -123,6 +125,15 @@ def __init__(
123125
# initialize the model and processor
124126
pre_process(model_args)
125127

128+
# offload to cpu if possible
129+
if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
130+
# TODO: consider renaming function similar to "offload_dispatch_model"
131+
# TODO: modify function to remove any hooks if they already exist (making
132+
# sure to move to cpu when removing hook)
133+
force_cpu_offload(model_args.model, model_args.oneshot_device)
134+
else:
135+
logger.warning("CUDA is not available! Compressing model on CPU instead")
136+
126137
# Set instance attributes
127138
self.model = self.model_args.model
128139
self.processor = self.model_args.processor

src/llmcompressor/entrypoints/utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from pathlib import PosixPath
44
from typing import Optional, Tuple
55

6-
import torch
7-
from compressed_tensors.utils import force_cpu_offload
86
from loguru import logger
97
from torch.nn import Module
108
from transformers import (
@@ -64,16 +62,6 @@ def pre_process(model_args: "ModelArguments"):
6462
# untie tie_word_embeddings weights
6563
patch_tied_tensors_bug(model_args.model)
6664

67-
# offload to cpu if possible
68-
if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
69-
# TODO: consider renaming function to something like "offload_dispatch_model"
70-
# TODO: modify function to remove any hooks if they already exist (making sure
71-
# to move to cpu when removing hook
72-
force_cpu_offload(model_args.model, model_args.oneshot_device)
73-
74-
else:
75-
logger.warning("CUDA is not available! Compressing model on CPU instead")
76-
7765
# wrap model.save_pretrained
7866
modify_save_pretrained(model_args.model)
7967

0 commit comments

Comments
 (0)