Commit 6705bf4

touchups based on remaining steps
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 1aea4dd commit 6705bf4

File tree

6 files changed: +70 additions, -32 deletions

examples/quantization_w4a16/llama3_example.py

Lines changed: 2 additions & 3 deletions

@@ -1,5 +1,4 @@
 import torch
-from compressed_tensors import force_cpu_offload
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -11,10 +10,9 @@
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
-    # device_map="auto",
+    device_map="cpu",
     torch_dtype="auto",
 )
-force_cpu_offload(model, execution_device=torch.device("cuda"))
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 # Select calibration dataset.
@@ -67,6 +65,7 @@ def tokenize(sample):
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    oneshot_device=torch.device("cuda") if torch.cuda.is_available() else None,
 )
 
 # Confirm generations of the quantized model look sane.
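
Taken together, the changes to this example load the model on CPU and hand the accelerator to the oneshot entrypoint instead of force-offloading up front. A minimal sketch of the resulting flow; the model id, dataset name, and GPTQ recipe below are illustrative stand-ins for the parts of the example the diff elides:

# Sketch of the updated flow: weights stay in host memory, oneshot onloads
# to GPU as needed. Dataset and recipe are assumptions, not from the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed model id

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",  # keep the full model in host memory
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# illustrative W4A16 recipe; the real example defines its own
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model=model,
    dataset="ultrachat_200k",  # assumption: any supported calibration set
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    # run compression on GPU when available, else stay on CPU
    oneshot_device=torch.device("cuda") if torch.cuda.is_available() else None,
)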

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 37 additions & 1 deletion

@@ -2,6 +2,9 @@
 from datetime import datetime
 from typing import Optional
 
+import torch
+from compressed_tensors import force_cpu_offload
+from compressed_tensors.utils import get_execution_device
 from loguru import logger
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel
@@ -10,7 +13,11 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
-from llmcompressor.pipelines.registry import CalibrationPipeline
+from llmcompressor.pipelines import (
+    CalibrationPipeline,
+    LayerSequentialPipeline,
+    SequentialPipeline,
+)
 
 __all__ = ["Oneshot", "oneshot"]
 
@@ -186,6 +193,35 @@ def apply_recipe_modifiers(
         user_pipeline = self.dataset_args.pipeline
         modifiers = session.get_modifiers()
         pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
+
+        model_exec_device = get_execution_device(self.model)
+
+        # Sequential pipelines onload models layer by layer to minimize GPU memory usage
+        if isinstance(pipeline, (SequentialPipeline, LayerSequentialPipeline)):
+            # unless this is a pure CPU run, warn if the model already lives
+            # on the oneshot execution device
+            if (
+                model_exec_device
+                == self.model_args.oneshot_device
+                != torch.device("cpu")
+            ):
+                logger.warning(
+                    f"Model device {model_exec_device} is the same as the oneshot"
                    " execution device. If you encounter OOM errors, consider"
                    " loading the model on CPU, so that more memory is available"
                    " for the oneshot algorithm to run on GPU. An example is"
                    " available at examples/quantization_w4a16/llama3_example.py"
+                )
+
+        # otherwise, set CPU offload so the model executes on the oneshot device
+        elif (
+            model_exec_device
+            == torch.device("cpu")
+            != self.model_args.oneshot_device
+        ):
+            force_cpu_offload(
+                self.model, execution_device=self.model_args.oneshot_device
+            )
+
         pipeline(self.model, calibration_dataloader, self.dataset_args)
 
         session.finalize()
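
One subtlety in the conditions above: Python chains comparisons, so `a == b != c` evaluates as `(a == b) and (b != c)`. A small self-contained check with illustrative devices:

# Python chained comparison: `a == b != c` means `(a == b) and (b != c)`.
# The warning above therefore fires only when the model's execution device
# equals oneshot_device AND that shared device is not the CPU.
import torch

exec_device = torch.device("cuda")
oneshot_device = torch.device("cuda")
cpu = torch.device("cpu")

should_warn = exec_device == oneshot_device != cpu
print(should_warn)  # True: model and oneshot share a non-CPU device

should_offload = exec_device == cpu != oneshot_device
print(should_offload)  # False here: the model is not on CPU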

src/llmcompressor/modifiers/awq/base.py

Lines changed: 0 additions & 1 deletion

@@ -34,7 +34,6 @@
 __all__ = ["AWQModifier"]
 
 
-# TODO (Brian INFERENG-531) Add support for offloaded models
 class AWQModifier(Modifier, QuantizationMixin):
     """
     Implements the AWQ (Activation-Weighted Quantization) algorithm,

src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 0 additions & 2 deletions

@@ -37,8 +37,6 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
-        # TODO: warn about cpu offloading
-
         model_device = get_execution_device(model)
 
         LifecycleCallbacks.calibration_epoch_start()

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 31 additions & 23 deletions

@@ -14,7 +14,10 @@
     to_next_layer_kwargs,
 )
 from llmcompressor.pipelines.registry import CalibrationPipeline
-from llmcompressor.pipelines.sequential.helpers import get_targets_from_modifiers
+from llmcompressor.pipelines.sequential.helpers import (
+    get_targets_from_modifiers,
+    keep_onload_context,
+)
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
 
 if TYPE_CHECKING:
@@ -73,29 +76,34 @@ def __call__(
             calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
             prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"
 
-            # do a preliminary pass to trigger modifier hooks
-            for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
-                inputs = intermediates.fetch(batch_idx)
-                layer(**inputs)
-
-            LifecycleCallbacks.sequential_epoch_end()
-
-            # this pass does not trigger modifier hooks
-            # and is only used for capturing outputs from newly compressed modules
-            with HooksMixin.disable_hooks():
-                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
+            # reduce memory movement by keeping modules onloaded
+            with keep_onload_context():
+                # do a preliminary pass to trigger modifier hooks
+                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                     inputs = intermediates.fetch(batch_idx)
-                    output = layer(**inputs)
-
-                    if layer_index < num_layers - 1:
-                        next_layer = layers[layer_index + 1]
-                        output = to_next_layer_kwargs(output, next_layer)
-                        output = maybe_inject_pos_embeddings(
-                            output, next_layer, inputs
-                        )
-
-                        intermediates.delete(batch_idx)
-                        intermediates.update(batch_idx, output)
+                    layer(**inputs)
+
+                LifecycleCallbacks.sequential_epoch_end()
+
+                # this pass does not trigger modifier hooks
+                # and is only used for capturing outputs from
+                # newly compressed modules
+                with HooksMixin.disable_hooks():
+                    for batch_idx in tqdm.tqdm(
+                        range(len(dataloader)), desc=prop_desc
+                    ):
+                        inputs = intermediates.fetch(batch_idx)
+                        output = layer(**inputs)
+
+                        if layer_index < num_layers - 1:
+                            next_layer = layers[layer_index + 1]
+                            output = to_next_layer_kwargs(output, next_layer)
+                            output = maybe_inject_pos_embeddings(
+                                output, next_layer, inputs
+                            )
+
+                            intermediates.delete(batch_idx)
+                            intermediates.update(batch_idx, output)
 
         # redundant, finish any remaining compression
        LifecycleCallbacks.calibration_epoch_end()
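
The point of the new `keep_onload_context` wrapper is that an offloaded layer's weights are moved to the execution device once per layer, rather than once per forward pass, across both the calibration and propagation loops. A rough sketch of the pattern; this is illustrative only, not the actual helper in llmcompressor.pipelines.sequential.helpers, which works with accelerate-style offload hooks:

# Illustrative keep-onload context: onload a module's weights once, run any
# number of forward passes, then offload once at the end. Without it, each
# pass would pay the host<->device transfer cost separately.
from contextlib import contextmanager

import torch

@contextmanager
def keep_onloaded(module: torch.nn.Module, device: torch.device):
    module.to(device)  # one host-to-device copy for all passes
    try:
        yield module
    finally:
        module.to("cpu")  # one device-to-host copy when done

# usage: both the calibration pass and the propagation pass reuse the
# onloaded weights instead of paying the transfer cost twice
layer = torch.nn.Linear(8, 8)
with keep_onloaded(layer, torch.device("cpu")):  # "cpu" so the sketch runs anywhere
    for _ in range(2):
        _ = layer(torch.randn(1, 8))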

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 0 additions & 2 deletions

@@ -52,8 +52,6 @@ def __call__(
         """
         session = active_session()
 
-        # TODO: warn about not cpu offloading
-
        # prepare to trace subgraphs
        modifiers = session.get_modifiers()
        sequential_targets = get_targets_from_modifiers(modifiers, model)
