Skip to content

Commit 457af3e

Browse files
author
Avishek Goswami
committed
Add optional prefetch to intermediates cache; enable for AWQ when offloading
- `IntermediatesCache.iter_prefetch()` overlaps onload of the next batch with consumption of the current batch via a background thread.
- AWQ `_run_samples` uses `iter_prefetch` when `offload_device` is set, overlapping the CPU->device transfer with module forward passes.
- Add `test_iter_prefetch_matches_iter` to verify that prefetch yields the same results as `iter`.

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent a33d4ff commit 457af3e

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -608,6 +608,7 @@ def _smooth(
608608
@torch.no_grad()
609609
def _run_samples(self, module: Module) -> list[torch.Tensor]:
610610
cache = self._parent_args_cache[module]
611+
# When offloading, prefetch overlaps CPU->device onload with forward pass.
611612
batch_iter = (
612613
cache.iter_prefetch()
613614
if self.offload_device is not None

src/llmcompressor/pipelines/cache.py

Lines changed: 4 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -203,7 +203,7 @@ def iter_prefetch(
203203
"""
204204
Iterate over batches with the next batch prefetched in a background thread.
205205
Overlaps onload from offload_device with consumption of the current batch,
206-
which can reduce latency when offloading to CPU.
206+
which can reduce wall-clock time when offloading to CPU.
207207
208208
Yields the same fetched batch dicts as :meth:`iter`; only the timing
209209
of onloads differs.
@@ -215,15 +215,16 @@ def iter_prefetch(
215215
future = None
216216
for batch_index in range(num_batches):
217217
if future is not None:
218-
yield future.result()
218+
current = future.result()
219219
else:
220-
yield self.fetch(batch_index, input_names)
220+
current = self.fetch(batch_index, input_names)
221221
if batch_index + 1 < num_batches:
222222
future = executor.submit(
223223
self.fetch, batch_index + 1, input_names
224224
)
225225
else:
226226
future = None
227+
yield current
227228

228229
def __iter__(self) -> Generator[Any, None, None]:
229230
yield from self.iter()

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -53,6 +53,13 @@ def test_initialization(sample_dataloader):
5353
assert isinstance(cache.batch_intermediates[0], dict)
5454

5555

56+
@pytest.mark.unit
57+
def test_iter_prefetch_empty_cache():
58+
"""iter_prefetch yields nothing when cache has no batches."""
59+
cache = IntermediatesCache.empty(0, torch.device("cpu"))
60+
assert list(cache.iter_prefetch()) == []
61+
62+
5663
@pytest.mark.unit
5764
def test_iter_prefetch_matches_iter(sample_cache):
5865
"""iter_prefetch yields the same batch contents as iter."""

0 commit comments

Comments (0)