Skip to content

Commit 9128c2b

Browse files
author
Avishek Goswami
committed
Add optional prefetch to intermediates cache; enable for AWQ when offloading
- IntermediatesCache.iter_prefetch() overlaps onload of next batch with consumption of current batch via a background thread
- AWQ _run_samples uses iter_prefetch when offload_device is set to overlap CPU->device transfer with module forward passes
- Add test_iter_prefetch_matches_iter to verify prefetch yields same results as iter

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent a33d4ff commit 9128c2b

File tree

4 files changed

+30
-30
lines changed

4 files changed

+30
-30
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ class AWQModifier(Modifier, QuantizationMixin):
138138
requirements but requires more time to move data between cpu and execution
139139
device. Defaults to None, so cached args are not offloaded. Consider setting
140140
to torch.device("cpu") if you are encountering OOM errors
141+
:param prefetch: when offloading, prefetch the next batch in a background thread
142+
to overlap CPU-to-device onload with the forward pass, reducing wall-clock
143+
time. Default False; set True when offload_device is set and GPU memory
144+
allows two batches on device simultaneously
141145
:param duo_scaling: whether to use duo scaling, which uses both input activations
142146
and weights to determine the scaling factor. Defaults to True
143147
If True, both activations and weights are used.
@@ -157,6 +161,7 @@ class AWQModifier(Modifier, QuantizationMixin):
157161
sequential_targets: str | list[str] | None = None
158162
mappings: list[AWQMapping] | None = None
159163
offload_device: torch.device | None | Sentinel = Sentinel("not_provided")
164+
prefetch: bool = False
160165
duo_scaling: bool | Literal["both"] = True
161166
n_grid: int = 20
162167

@@ -608,11 +613,8 @@ def _smooth(
608613
@torch.no_grad()
609614
def _run_samples(self, module: Module) -> list[torch.Tensor]:
610615
cache = self._parent_args_cache[module]
611-
batch_iter = (
612-
cache.iter_prefetch()
613-
if self.offload_device is not None
614-
else cache
615-
)
616+
# Prefetch overlaps CPU->device onload with forward pass when offloading.
617+
batch_iter = cache.iter_prefetch() if self.prefetch else cache
616618
outputs = [module(**batch_kwargs) for batch_kwargs in batch_iter]
617619
return [
618620
# If tuple, assume that first argument is the input

src/llmcompressor/pipelines/cache.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def iter_prefetch(
203203
"""
204204
Iterate over batches with the next batch prefetched in a background thread.
205205
Overlaps onload from offload_device with consumption of the current batch,
206-
which can reduce latency when offloading to CPU.
206+
which can reduce wall-clock time when offloading to CPU.
207207
208208
Yields the same fetched batch dicts as :meth:`iter`; only the timing
209209
of onloads differs.
@@ -215,15 +215,14 @@ def iter_prefetch(
215215
future = None
216216
for batch_index in range(num_batches):
217217
if future is not None:
218-
yield future.result()
218+
current = future.result()
219219
else:
220-
yield self.fetch(batch_index, input_names)
220+
current = self.fetch(batch_index, input_names)
221221
if batch_index + 1 < num_batches:
222-
future = executor.submit(
223-
self.fetch, batch_index + 1, input_names
224-
)
222+
future = executor.submit(self.fetch, batch_index + 1, input_names)
225223
else:
226224
future = None
225+
yield current
227226

228227
def __iter__(self) -> Generator[Any, None, None]:
229228
yield from self.iter()

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
from concurrent.futures import ThreadPoolExecutor
32
from typing import TYPE_CHECKING, Iterator
43

54
import torch
@@ -40,25 +39,18 @@ def _get_batches(
4039
"""
4140
Yield (batch_idx, inputs) with the next batch optionally prefetched in a
4241
background thread to overlap fetch (onload from offload device) with the
43-
main-thread forward pass.
42+
main-thread forward pass. Delegates to
43+
:meth:`IntermediatesCache.iter_prefetch` when prefetching is enabled.
4444
"""
45-
if not use_prefetch:
46-
for batch_idx in tqdm(range(num_batches), desc=desc):
47-
inputs = activations.fetch(batch_idx, input_names)
48-
yield batch_idx, inputs
49-
return
50-
with ThreadPoolExecutor(max_workers=1) as executor:
51-
future = None
52-
for batch_idx in tqdm(range(num_batches), desc=desc):
53-
if future is not None:
54-
inputs = future.result()
55-
else:
56-
inputs = activations.fetch(batch_idx, input_names)
57-
if batch_idx + 1 < num_batches:
58-
future = executor.submit(activations.fetch, batch_idx + 1, input_names)
59-
else:
60-
future = None
61-
yield batch_idx, inputs
45+
batch_source = (
46+
activations.iter_prefetch(input_names)
47+
if use_prefetch
48+
else activations.iter(input_names)
49+
)
50+
for batch_idx, inputs in tqdm(
51+
enumerate(batch_source), total=num_batches, desc=desc
52+
):
53+
yield batch_idx, inputs
6254

6355

6456
@CalibrationPipeline.register("sequential")

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def test_initialization(sample_dataloader):
5353
assert isinstance(cache.batch_intermediates[0], dict)
5454

5555

56+
@pytest.mark.unit
57+
def test_iter_prefetch_empty_cache():
58+
"""iter_prefetch yields nothing when cache has no batches."""
59+
cache = IntermediatesCache.empty(0, torch.device("cpu"))
60+
assert list(cache.iter_prefetch()) == []
61+
62+
5663
@pytest.mark.unit
5764
def test_iter_prefetch_matches_iter(sample_cache):
5865
"""iter_prefetch yields the same batch contents as iter."""

0 commit comments

Comments
 (0)