Skip to content

Commit 22b62b9

Browse files
author
Avishek Goswami
committed
Address review: move deep_equal, make prefetch a global setting via State
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 9128c2b commit 22b62b9

File tree

4 files changed

+34
-36
lines changed

(see summary above: 4 files changed, +34 −36 lines)

src/llmcompressor/core/state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class State:
114114
_last_log_step: float | int | None = None
115115
loss_masks: list[torch.Tensor] | None = None
116116
current_batch_idx: int = -1
117+
sequential_prefetch: bool = False
117118

118119
@property
119120
def compression_ready(self) -> bool:

src/llmcompressor/modifiers/awq/base.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,6 @@ class AWQModifier(Modifier, QuantizationMixin):
138138
requirements but requires more time to move data between cpu and execution
139139
device. Defaults to None, so cached args are not offloaded. Consider setting
140140
to torch.device("cpu") if you are encountering OOM errors
141-
:param prefetch: when offloading, prefetch the next batch in a background thread
142-
to overlap CPU-to-device onload with the forward pass, reducing wall-clock
143-
time. Default False; set True when offload_device is set and GPU memory
144-
allows two batches on device simultaneously
145141
:param duo_scaling: whether to use duo scaling, which uses both input activations
146142
and weights to determine the scaling factor. Defaults to True
147143
If True, both activations and weights are used.
@@ -161,7 +157,6 @@ class AWQModifier(Modifier, QuantizationMixin):
161157
sequential_targets: str | list[str] | None = None
162158
mappings: list[AWQMapping] | None = None
163159
offload_device: torch.device | None | Sentinel = Sentinel("not_provided")
164-
prefetch: bool = False
165160
duo_scaling: bool | Literal["both"] = True
166161
n_grid: int = 20
167162

@@ -613,8 +608,8 @@ def _smooth(
613608
@torch.no_grad()
614609
def _run_samples(self, module: Module) -> list[torch.Tensor]:
615610
cache = self._parent_args_cache[module]
616-
# Prefetch overlaps CPU->device onload with forward pass when offloading.
617-
batch_iter = cache.iter_prefetch() if self.prefetch else cache
611+
use_prefetch = active_session().state.sequential_prefetch
612+
batch_iter = cache.iter_prefetch() if use_prefetch else cache
618613
outputs = [module(**batch_kwargs) for batch_kwargs in batch_iter]
619614
return [
620615
# If tuple, assume that first argument is the input

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_batches(
3434
num_batches: int,
3535
input_names: list[str],
3636
desc: str,
37-
use_prefetch: bool = False,
37+
sequential_prefetch: bool = False,
3838
) -> Iterator[tuple[int, dict]]:
3939
"""
4040
Yield (batch_idx, inputs) with the next batch optionally prefetched in a
@@ -44,7 +44,7 @@ def _get_batches(
4444
"""
4545
batch_source = (
4646
activations.iter_prefetch(input_names)
47-
if use_prefetch
47+
if sequential_prefetch
4848
else activations.iter(input_names)
4949
)
5050
for batch_idx, inputs in tqdm(
@@ -131,22 +131,24 @@ def __call__(
131131
else:
132132
session.state.loss_masks = None
133133

134+
sequential_prefetch = getattr(dataset_args, "sequential_prefetch", False)
135+
session.state.sequential_prefetch = sequential_prefetch
136+
134137
for subgraph_index, subgraph in enumerate(subgraphs):
135138
# prepare tqdm description texts
136139
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
137140
prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
138141

139142
# reduce memory movement by keeping modules onloaded
140143
num_batches = len(dataloader)
141-
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
142144
with disable_offloading():
143145
# do a preliminary pass to trigger modifier hooks
144146
for batch_idx, inputs in _get_batches(
145147
activations,
146148
num_batches,
147149
subgraph.input_names,
148150
calib_desc,
149-
use_prefetch,
151+
sequential_prefetch,
150152
):
151153
session.state.current_batch_idx = batch_idx
152154
subgraph.forward(model, **inputs)
@@ -161,7 +163,7 @@ def __call__(
161163
num_batches,
162164
subgraph.input_names,
163165
prop_desc,
164-
use_prefetch,
166+
sequential_prefetch,
165167
):
166168
output = subgraph.forward(model, **inputs)
167169
if subgraph_index < num_subgraphs - 1:

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,30 +76,6 @@ def batch_dicts_equal(a: dict, b: dict) -> bool:
7676
assert batch_dicts_equal(b_iter, b_prefetch), f"batch {i} differs"
7777

7878

79-
def deep_equal(a, b) -> bool:
80-
if type(a) is not type(b):
81-
return False
82-
83-
match a:
84-
case torch.Tensor():
85-
return torch.equal(a, b)
86-
case list() | tuple():
87-
if len(a) != len(b):
88-
return False
89-
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
90-
case dict():
91-
if a.keys() != b.keys():
92-
return False
93-
return all(deep_equal(a[key], b[key]) for key in a.keys())
94-
case _ if is_dataclass(a):
95-
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
96-
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
97-
98-
return deep_equal(a_dict, b_dict)
99-
case _:
100-
return a == b
101-
102-
10379
@pytest.mark.unit
10480
def test_fetch_inputs(sample_cache):
10581
fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"])
@@ -187,6 +163,30 @@ def test_device_handling(sample_dataloader):
187163
assert fetched["hidden_states"].device.type == "cuda"
188164

189165

166+
def deep_equal(a, b) -> bool:
167+
if type(a) is not type(b):
168+
return False
169+
170+
match a:
171+
case torch.Tensor():
172+
return torch.equal(a, b)
173+
case list() | tuple():
174+
if len(a) != len(b):
175+
return False
176+
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
177+
case dict():
178+
if a.keys() != b.keys():
179+
return False
180+
return all(deep_equal(a[key], b[key]) for key in a.keys())
181+
case _ if is_dataclass(a):
182+
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
183+
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
184+
185+
return deep_equal(a_dict, b_dict)
186+
case _:
187+
return a == b
188+
189+
190190
def test_override_eq_mode():
191191
a = torch.tensor([1, 2, 3])
192192
b = a

0 commit comments

Comments (0)