Skip to content

Commit 595fd90

Browse files
Author: Avishek Goswami (committed)
Address review: move deep_equal, make prefetch a global setting via State
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 9128c2b commit 595fd90

File tree

4 files changed

+30
-32
lines changed

4 files changed

+30
-32
lines changed

src/llmcompressor/core/state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class State:
114114
_last_log_step: float | int | None = None
115115
loss_masks: list[torch.Tensor] | None = None
116116
current_batch_idx: int = -1
117+
sequential_prefetch: bool = False
117118

118119
@property
119120
def compression_ready(self) -> bool:

src/llmcompressor/modifiers/awq/base.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,6 @@ class AWQModifier(Modifier, QuantizationMixin):
138138
requirements but requires more time to move data between cpu and execution
139139
device. Defaults to None, so cached args are not offloaded. Consider setting
140140
to torch.device("cpu") if you are encountering OOM errors
141-
:param prefetch: when offloading, prefetch the next batch in a background thread
142-
to overlap CPU-to-device onload with the forward pass, reducing wall-clock
143-
time. Default False; set True when offload_device is set and GPU memory
144-
allows two batches on device simultaneously
145141
:param duo_scaling: whether to use duo scaling, which uses both input activations
146142
and weights to determine the scaling factor. Defaults to True
147143
If True, both activations and weights are used.
@@ -161,7 +157,6 @@ class AWQModifier(Modifier, QuantizationMixin):
161157
sequential_targets: str | list[str] | None = None
162158
mappings: list[AWQMapping] | None = None
163159
offload_device: torch.device | None | Sentinel = Sentinel("not_provided")
164-
prefetch: bool = False
165160
duo_scaling: bool | Literal["both"] = True
166161
n_grid: int = 20
167162

@@ -613,8 +608,8 @@ def _smooth(
613608
@torch.no_grad()
614609
def _run_samples(self, module: Module) -> list[torch.Tensor]:
615610
cache = self._parent_args_cache[module]
616-
# Prefetch overlaps CPU->device onload with forward pass when offloading.
617-
batch_iter = cache.iter_prefetch() if self.prefetch else cache
611+
use_prefetch = active_session().state.sequential_prefetch
612+
batch_iter = cache.iter_prefetch() if use_prefetch else cache
618613
outputs = [module(**batch_kwargs) for batch_kwargs in batch_iter]
619614
return [
620615
# If tuple, assume that first argument is the input

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,16 @@ def __call__(
131131
else:
132132
session.state.loss_masks = None
133133

134+
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
135+
session.state.sequential_prefetch = use_prefetch
136+
134137
for subgraph_index, subgraph in enumerate(subgraphs):
135138
# prepare tqdm description texts
136139
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
137140
prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
138141

139142
# reduce memory movement by keeping modules onloaded
140143
num_batches = len(dataloader)
141-
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
142144
with disable_offloading():
143145
# do a preliminary pass to trigger modifier hooks
144146
for batch_idx, inputs in _get_batches(

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,30 @@ def sample_cache(sample_dataloader):
4141
]
4242

4343

44+
def deep_equal(a, b) -> bool:
45+
if type(a) is not type(b):
46+
return False
47+
48+
match a:
49+
case torch.Tensor():
50+
return torch.equal(a, b)
51+
case list() | tuple():
52+
if len(a) != len(b):
53+
return False
54+
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
55+
case dict():
56+
if a.keys() != b.keys():
57+
return False
58+
return all(deep_equal(a[key], b[key]) for key in a.keys())
59+
case _ if is_dataclass(a):
60+
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
61+
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
62+
63+
return deep_equal(a_dict, b_dict)
64+
case _:
65+
return a == b
66+
67+
4468
@pytest.mark.unit
4569
def test_initialization(sample_dataloader):
4670
cache = IntermediatesCache.from_dataloader(
@@ -76,30 +100,6 @@ def batch_dicts_equal(a: dict, b: dict) -> bool:
76100
assert batch_dicts_equal(b_iter, b_prefetch), f"batch {i} differs"
77101

78102

79-
def deep_equal(a, b) -> bool:
80-
if type(a) is not type(b):
81-
return False
82-
83-
match a:
84-
case torch.Tensor():
85-
return torch.equal(a, b)
86-
case list() | tuple():
87-
if len(a) != len(b):
88-
return False
89-
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
90-
case dict():
91-
if a.keys() != b.keys():
92-
return False
93-
return all(deep_equal(a[key], b[key]) for key in a.keys())
94-
case _ if is_dataclass(a):
95-
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
96-
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
97-
98-
return deep_equal(a_dict, b_dict)
99-
case _:
100-
return a == b
101-
102-
103103
@pytest.mark.unit
104104
def test_fetch_inputs(sample_cache):
105105
fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"])

0 commit comments

Comments (0)