Skip to content

Commit 353f556

Browse files
GOavi101 (Avishek Goswami) and HDCharles
authored
Feature/intermediates cache prefetch (#2392)
Optional prefetch was added to the intermediates cache and wired into AWQ when offloading. **IntermediatesCache** A new method, iter_prefetch(), iterates over batches like iter() but prefetches the next batch in a background thread so onload from the offload device overlaps with use of the current batch, reducing wall-clock time when offloading to CPU. **AWQ** When offload_device is set, _run_samples() uses cache.iter_prefetch() instead of the plain cache iterator, so CPU→device onload overlaps with the forward pass over cached parent args during smoothing. **Tests** Two tests were added: one verifying that iter_prefetch() yields the same batches as iter(), and one verifying that iter_prefetch() on an empty cache yields nothing. No new public API; prefetch is used automatically when AWQ offloads. Fixes: #2374 --------- Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com> Co-authored-by: Avishek Goswami <avishek.goswami@ibm.com> Co-authored-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
1 parent 822668a commit 353f556

File tree

5 files changed

+114
-27
lines changed

5 files changed

+114
-27
lines changed

src/llmcompressor/core/state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class State:
104104
hardware: Hardware = field(default_factory=Hardware)
105105
loss_masks: list[torch.Tensor] | None = None
106106
current_batch_idx: int = -1
107+
sequential_prefetch: bool = False
107108

108109
@property
109110
def compression_ready(self) -> bool:

src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,10 @@ def _smooth(
607607

608608
@torch.no_grad()
609609
def _run_samples(self, module: Module) -> list[torch.Tensor]:
610-
outputs = [
611-
module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
612-
]
610+
cache = self._parent_args_cache[module]
611+
use_prefetch = active_session().state.sequential_prefetch
612+
batch_iter = cache.iter_prefetch() if use_prefetch else cache
613+
outputs = [module(**batch_kwargs) for batch_kwargs in batch_iter]
613614
return [
614615
# If tuple, assume that first argument is the input
615616
output[0] if isinstance(output, tuple) else output

src/llmcompressor/pipelines/cache.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import warnings
55
from collections import defaultdict
6+
from concurrent.futures import ThreadPoolExecutor
67
from dataclasses import dataclass, fields, is_dataclass
78
from typing import Any, Generator
89
from weakref import WeakKeyDictionary
@@ -196,6 +197,59 @@ def iter(self, input_names: list[str] | None = None) -> Generator[Any, None, Non
196197
for batch_index in range(len(self.batch_intermediates)):
197198
yield self.fetch(batch_index, input_names)
198199

200+
def iter_prefetch(
201+
self, input_names: list[str] | None = None
202+
) -> Generator[Any, None, None]:
203+
"""
204+
Iterate over batches with the next batch prefetched in a background thread.
205+
Overlaps onload from offload_device with consumption of the current batch,
206+
which can reduce wall-clock time when offloading to CPU.
207+
208+
When CUDA is available, uses non_blocking transfers (requires pinned CPU
209+
tensors, set up by _offload_value) and synchronises via CUDA events so the
210+
main stream waits for each H2D copy before running GPU kernels on the data.
211+
212+
Yields the same fetched batch dicts as :meth:`iter`; only the timing
213+
of onloads differs.
214+
"""
215+
num_batches = len(self.batch_intermediates)
216+
if num_batches == 0:
217+
return
218+
219+
# Create a dedicated CUDA stream for H2D transfers so they run on a
220+
# separate stream from the main thread's compute stream. Without this,
221+
# both threads default to the null stream (stream 0) which serializes
222+
# all operations and prevents any overlap.
223+
h2d_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
224+
225+
def _fetch_and_record(batch_index):
226+
event = None
227+
if h2d_stream is not None:
228+
with torch.cuda.stream(h2d_stream):
229+
data = self.fetch(batch_index, input_names)
230+
event = torch.cuda.Event()
231+
event.record(h2d_stream)
232+
else:
233+
data = self.fetch(batch_index, input_names)
234+
return data, event
235+
236+
with ThreadPoolExecutor(max_workers=1) as executor:
237+
future = None
238+
for batch_index in range(num_batches):
239+
if future is not None:
240+
current, event = future.result()
241+
else:
242+
current, event = _fetch_and_record(batch_index)
243+
if batch_index + 1 < num_batches:
244+
future = executor.submit(_fetch_and_record, batch_index + 1)
245+
else:
246+
future = None
247+
# Make the main CUDA stream wait for the background H2D copy
248+
# before any GPU kernel consumes the prefetched tensors
249+
if event is not None:
250+
torch.cuda.current_stream().wait_event(event)
251+
yield current
252+
199253
def __iter__(self) -> Generator[Any, None, None]:
200254
yield from self.iter()
201255

@@ -215,7 +269,14 @@ def _onload_value(cls, intermediate: IntermediateValue) -> Any:
215269

216270
match value:
217271
case torch.Tensor():
218-
return value.to(device=device)
272+
# use non_blocking when source is pinned and target is CUDA so the
273+
# H2D DMA can overlap with GPU compute on a separate CUDA stream
274+
non_blocking = (
275+
value.is_pinned()
276+
and device is not None
277+
and torch.device(device).type == "cuda"
278+
)
279+
return value.to(device=device, non_blocking=non_blocking)
219280
case list():
220281
return [cls._onload_value(v) for v in value]
221282
case tuple():
@@ -259,6 +320,13 @@ def _offload_value(
259320
# move to offload if no hit
260321
offloaded = value.to(device=offload_device)
261322
if offloaded is not value: # avoid circular ref
323+
# pin CPU tensors so onload can use non_blocking DMA
324+
if (
325+
torch.device(offload_device).type == "cpu"
326+
and torch.cuda.is_available()
327+
and not offloaded.is_pinned()
328+
):
329+
offloaded = offloaded.pin_memory()
262330
cls.offload_values[value] = offloaded
263331

264332
return IntermediateValue(

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
from concurrent.futures import ThreadPoolExecutor
32
from typing import TYPE_CHECKING, Iterator
43

54
import torch
@@ -35,30 +34,23 @@ def _get_batches(
3534
num_batches: int,
3635
input_names: list[str],
3736
desc: str,
38-
use_prefetch: bool = False,
37+
sequential_prefetch: bool = False,
3938
) -> Iterator[tuple[int, dict]]:
4039
"""
4140
Yield (batch_idx, inputs) with the next batch optionally prefetched in a
4241
background thread to overlap fetch (onload from offload device) with the
43-
main-thread forward pass.
42+
main-thread forward pass. Delegates to
43+
:meth:`IntermediatesCache.iter_prefetch` when prefetching is enabled.
4444
"""
45-
if not use_prefetch:
46-
for batch_idx in tqdm(range(num_batches), desc=desc):
47-
inputs = activations.fetch(batch_idx, input_names)
48-
yield batch_idx, inputs
49-
return
50-
with ThreadPoolExecutor(max_workers=1) as executor:
51-
future = None
52-
for batch_idx in tqdm(range(num_batches), desc=desc):
53-
if future is not None:
54-
inputs = future.result()
55-
else:
56-
inputs = activations.fetch(batch_idx, input_names)
57-
if batch_idx + 1 < num_batches:
58-
future = executor.submit(activations.fetch, batch_idx + 1, input_names)
59-
else:
60-
future = None
61-
yield batch_idx, inputs
45+
batch_source = (
46+
activations.iter_prefetch(input_names)
47+
if sequential_prefetch
48+
else activations.iter(input_names)
49+
)
50+
for batch_idx, inputs in tqdm(
51+
enumerate(batch_source), total=num_batches, desc=desc
52+
):
53+
yield batch_idx, inputs
6254

6355

6456
@CalibrationPipeline.register("sequential")
@@ -139,22 +131,24 @@ def __call__(
139131
else:
140132
session.state.loss_masks = None
141133

134+
sequential_prefetch = getattr(dataset_args, "sequential_prefetch", False)
135+
session.state.sequential_prefetch = sequential_prefetch
136+
142137
for subgraph_index, subgraph in enumerate(subgraphs):
143138
# prepare tqdm description texts
144139
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
145140
prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
146141

147142
# reduce memory movement by keeping modules onloaded
148143
num_batches = len(dataloader)
149-
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
150144
with disable_offloading():
151145
# do a preliminary pass to trigger modifier hooks
152146
for batch_idx, inputs in _get_batches(
153147
activations,
154148
num_batches,
155149
subgraph.input_names,
156150
calib_desc,
157-
use_prefetch,
151+
sequential_prefetch,
158152
):
159153
session.state.current_batch_idx = batch_idx
160154
subgraph.forward(model, **inputs)
@@ -169,7 +163,7 @@ def __call__(
169163
num_batches,
170164
subgraph.input_names,
171165
prop_desc,
172-
use_prefetch,
166+
sequential_prefetch,
173167
):
174168
output = subgraph.forward(model, **inputs)
175169
if subgraph_index < num_subgraphs - 1:

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,29 @@ def test_initialization(sample_dataloader):
5353
assert isinstance(cache.batch_intermediates[0], dict)
5454

5555

56+
@pytest.mark.unit
57+
def test_iter_prefetch_empty_cache():
58+
"""iter_prefetch yields nothing when cache has no batches."""
59+
cache = IntermediatesCache.empty(0, torch.device("cpu"))
60+
assert list(cache.iter_prefetch()) == []
61+
62+
63+
@pytest.mark.unit
64+
def test_iter_prefetch_matches_iter(sample_cache):
65+
"""iter_prefetch yields the same batch contents as iter."""
66+
67+
def batch_dicts_equal(a: dict, b: dict) -> bool:
68+
if set(a.keys()) != set(b.keys()):
69+
return False
70+
return all(deep_equal(a[k], b[k]) for k in a)
71+
72+
via_iter = list(sample_cache.iter())
73+
via_prefetch = list(sample_cache.iter_prefetch())
74+
assert len(via_iter) == len(via_prefetch)
75+
for i, (b_iter, b_prefetch) in enumerate(zip(via_iter, via_prefetch)):
76+
assert batch_dicts_equal(b_iter, b_prefetch), f"batch {i} differs"
77+
78+
5679
@pytest.mark.unit
5780
def test_fetch_inputs(sample_cache):
5881
fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"])

0 commit comments

Comments (0)