Skip to content

Commit b6a3282

Browse files
author
Avishek Goswami
committed
Address review: move deep_equal, make prefetch a global setting via State
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 9128c2b commit b6a3282

File tree

5 files changed

+78
-40
lines changed

5 files changed

+78
-40
lines changed

src/llmcompressor/core/state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class State:
114114
_last_log_step: float | int | None = None
115115
loss_masks: list[torch.Tensor] | None = None
116116
current_batch_idx: int = -1
117+
sequential_prefetch: bool = False
117118

118119
@property
119120
def compression_ready(self) -> bool:

src/llmcompressor/modifiers/awq/base.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,6 @@ class AWQModifier(Modifier, QuantizationMixin):
138138
requirements but requires more time to move data between cpu and execution
139139
device. Defaults to None, so cached args are not offloaded. Consider setting
140140
to torch.device("cpu") if you are encountering OOM errors
141-
:param prefetch: when offloading, prefetch the next batch in a background thread
142-
to overlap CPU-to-device onload with the forward pass, reducing wall-clock
143-
time. Default False; set True when offload_device is set and GPU memory
144-
allows two batches on device simultaneously
145141
:param duo_scaling: whether to use duo scaling, which uses both input activations
146142
and weights to determine the scaling factor. Defaults to True
147143
If True, both activations and weights are used.
@@ -161,7 +157,6 @@ class AWQModifier(Modifier, QuantizationMixin):
161157
sequential_targets: str | list[str] | None = None
162158
mappings: list[AWQMapping] | None = None
163159
offload_device: torch.device | None | Sentinel = Sentinel("not_provided")
164-
prefetch: bool = False
165160
duo_scaling: bool | Literal["both"] = True
166161
n_grid: int = 20
167162

@@ -613,8 +608,8 @@ def _smooth(
613608
@torch.no_grad()
614609
def _run_samples(self, module: Module) -> list[torch.Tensor]:
615610
cache = self._parent_args_cache[module]
616-
# Prefetch overlaps CPU->device onload with forward pass when offloading.
617-
batch_iter = cache.iter_prefetch() if self.prefetch else cache
611+
use_prefetch = active_session().state.sequential_prefetch
612+
batch_iter = cache.iter_prefetch() if use_prefetch else cache
618613
outputs = [module(**batch_kwargs) for batch_kwargs in batch_iter]
619614
return [
620615
# If tuple, assume that first argument is the input

src/llmcompressor/pipelines/cache.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,49 @@ def iter_prefetch(
205205
Overlaps onload from offload_device with consumption of the current batch,
206206
which can reduce wall-clock time when offloading to CPU.
207207
208+
When CUDA is available, uses non_blocking transfers (requires pinned CPU
209+
tensors, set up by _offload_value) and synchronises via CUDA events so the
210+
main stream waits for each H2D copy before running GPU kernels on the data.
211+
208212
Yields the same fetched batch dicts as :meth:`iter`; only the timing
209213
of onloads differs.
210214
"""
211215
num_batches = len(self.batch_intermediates)
212216
if num_batches == 0:
213217
return
218+
219+
# Create a dedicated CUDA stream for H2D transfers so they run on a
220+
# separate stream from the main thread's compute stream. Without this,
221+
# both threads default to the null stream (stream 0) which serializes
222+
# all operations and prevents any overlap.
223+
h2d_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
224+
225+
def _fetch_and_record(batch_index):
226+
event = None
227+
if h2d_stream is not None:
228+
with torch.cuda.stream(h2d_stream):
229+
data = self.fetch(batch_index, input_names)
230+
event = torch.cuda.Event()
231+
event.record(h2d_stream)
232+
else:
233+
data = self.fetch(batch_index, input_names)
234+
return data, event
235+
214236
with ThreadPoolExecutor(max_workers=1) as executor:
215237
future = None
216238
for batch_index in range(num_batches):
217239
if future is not None:
218-
current = future.result()
240+
current, event = future.result()
219241
else:
220-
current = self.fetch(batch_index, input_names)
242+
current, event = _fetch_and_record(batch_index)
221243
if batch_index + 1 < num_batches:
222-
future = executor.submit(self.fetch, batch_index + 1, input_names)
244+
future = executor.submit(_fetch_and_record, batch_index + 1)
223245
else:
224246
future = None
247+
# Make the main CUDA stream wait for the background H2D copy
248+
# before any GPU kernel consumes the prefetched tensors
249+
if event is not None:
250+
torch.cuda.current_stream().wait_event(event)
225251
yield current
226252

227253
def __iter__(self) -> Generator[Any, None, None]:
@@ -243,7 +269,14 @@ def _onload_value(cls, intermediate: IntermediateValue) -> Any:
243269

244270
match value:
245271
case torch.Tensor():
246-
return value.to(device=device)
272+
# use non_blocking when source is pinned and target is CUDA so the
273+
# H2D DMA can overlap with GPU compute on a separate CUDA stream
274+
non_blocking = (
275+
value.is_pinned()
276+
and device is not None
277+
and torch.device(device).type == "cuda"
278+
)
279+
return value.to(device=device, non_blocking=non_blocking)
247280
case list():
248281
return [cls._onload_value(v) for v in value]
249282
case tuple():
@@ -287,6 +320,13 @@ def _offload_value(
287320
# move to offload if no hit
288321
offloaded = value.to(device=offload_device)
289322
if offloaded is not value: # avoid circular ref
323+
# pin CPU tensors so onload can use non_blocking DMA
324+
if (
325+
torch.device(offload_device).type == "cpu"
326+
and torch.cuda.is_available()
327+
and not offloaded.is_pinned()
328+
):
329+
offloaded = offloaded.pin_memory()
290330
cls.offload_values[value] = offloaded
291331

292332
return IntermediateValue(

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_batches(
3434
num_batches: int,
3535
input_names: list[str],
3636
desc: str,
37-
use_prefetch: bool = False,
37+
sequential_prefetch: bool = False,
3838
) -> Iterator[tuple[int, dict]]:
3939
"""
4040
Yield (batch_idx, inputs) with the next batch optionally prefetched in a
@@ -44,7 +44,7 @@ def _get_batches(
4444
"""
4545
batch_source = (
4646
activations.iter_prefetch(input_names)
47-
if use_prefetch
47+
if sequential_prefetch
4848
else activations.iter(input_names)
4949
)
5050
for batch_idx, inputs in tqdm(
@@ -131,22 +131,24 @@ def __call__(
131131
else:
132132
session.state.loss_masks = None
133133

134+
sequential_prefetch = getattr(dataset_args, "sequential_prefetch", False)
135+
session.state.sequential_prefetch = sequential_prefetch
136+
134137
for subgraph_index, subgraph in enumerate(subgraphs):
135138
# prepare tqdm description texts
136139
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
137140
prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
138141

139142
# reduce memory movement by keeping modules onloaded
140143
num_batches = len(dataloader)
141-
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
142144
with disable_offloading():
143145
# do a preliminary pass to trigger modifier hooks
144146
for batch_idx, inputs in _get_batches(
145147
activations,
146148
num_batches,
147149
subgraph.input_names,
148150
calib_desc,
149-
use_prefetch,
151+
sequential_prefetch,
150152
):
151153
session.state.current_batch_idx = batch_idx
152154
subgraph.forward(model, **inputs)
@@ -161,7 +163,7 @@ def __call__(
161163
num_batches,
162164
subgraph.input_names,
163165
prop_desc,
164-
use_prefetch,
166+
sequential_prefetch,
165167
):
166168
output = subgraph.forward(model, **inputs)
167169
if subgraph_index < num_subgraphs - 1:

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,30 +76,6 @@ def batch_dicts_equal(a: dict, b: dict) -> bool:
7676
assert batch_dicts_equal(b_iter, b_prefetch), f"batch {i} differs"
7777

7878

79-
def deep_equal(a, b) -> bool:
80-
if type(a) is not type(b):
81-
return False
82-
83-
match a:
84-
case torch.Tensor():
85-
return torch.equal(a, b)
86-
case list() | tuple():
87-
if len(a) != len(b):
88-
return False
89-
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
90-
case dict():
91-
if a.keys() != b.keys():
92-
return False
93-
return all(deep_equal(a[key], b[key]) for key in a.keys())
94-
case _ if is_dataclass(a):
95-
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
96-
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
97-
98-
return deep_equal(a_dict, b_dict)
99-
case _:
100-
return a == b
101-
102-
10379
@pytest.mark.unit
10480
def test_fetch_inputs(sample_cache):
10581
fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"])
@@ -187,6 +163,30 @@ def test_device_handling(sample_dataloader):
187163
assert fetched["hidden_states"].device.type == "cuda"
188164

189165

166+
def deep_equal(a, b) -> bool:
167+
if type(a) is not type(b):
168+
return False
169+
170+
match a:
171+
case torch.Tensor():
172+
return torch.equal(a, b)
173+
case list() | tuple():
174+
if len(a) != len(b):
175+
return False
176+
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
177+
case dict():
178+
if a.keys() != b.keys():
179+
return False
180+
return all(deep_equal(a[key], b[key]) for key in a.keys())
181+
case _ if is_dataclass(a):
182+
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
183+
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
184+
185+
return deep_equal(a_dict, b_dict)
186+
case _:
187+
return a == b
188+
189+
190190
def test_override_eq_mode():
191191
a = torch.tensor([1, 2, 3])
192192
b = a

0 commit comments

Comments (0)