Skip to content

Commit e5780c5

Browse files
authored
[Utils] Replace preserve_attr with patch_attr (#1187)
## Purpose ## * Provide explicit patching functionality to `preserve_attr` * This function is very similar to `unittest.mock.patch`, except that using this functionality does not require using the `unittest` library in source code ## Changes ## * Replace `preserve_attr` with `patch_attr` * Replace usage in `src/llmcompressor/pipelines/sequential/helpers.py`, which helps with clarity * Add unit test mark to utils/helpers pytests ## Testing ## * Added tests --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 73eb664 commit e5780c5

File tree

3 files changed

+60
-14
lines changed

3 files changed

+60
-14
lines changed

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from transformers.utils.fx import HFTracer
1414

1515
from llmcompressor.modifiers.utils.hooks import HooksMixin
16-
from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr
16+
from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
1717

1818
__all__ = ["trace_subgraphs", "Subgraph"]
1919

@@ -132,15 +132,14 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
132132

133133
def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
134134
if isinstance(root, Module):
135-
with preserve_attr(type(root), "forward"):
136-
# due to a bug in Tracer.create_args_for_root (_patch_function),
137-
# we must unwrap function wrappers prior to tracing, for example
138-
# the `deprecate_kwarg` by transformers which wraps forward
139-
140-
# we override the class method because the
141-
# class method is the one being traced
142-
type(root).forward = inspect.unwrap(type(root).forward)
143-
135+
# due to a bug in Tracer.create_args_for_root (_patch_function),
136+
# we must unwrap function wrappers prior to tracing, for example
137+
# the `deprecate_kwarg` by transformers which wraps forward
138+
unwrapped_forward = inspect.unwrap(type(root).forward)
139+
140+
# we override the class method because the
141+
# class method is the one being traced
142+
with patch_attr(type(root), "forward", unwrapped_forward):
144143
return super().trace(root, *args, **kwargs)
145144

146145
else:

src/llmcompressor/utils/helpers.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"DisableQuantization",
6666
"eval_context",
6767
"calibration_forward_context",
68-
"preserve_attr",
68+
"patch_attr",
6969
]
7070

7171

@@ -1051,9 +1051,29 @@ def calibration_forward_context(model: PreTrainedModel):
10511051

10521052

10531053
@contextlib.contextmanager
1054-
def preserve_attr(base: object, attr: str):
1055-
value = getattr(base, attr)
1054+
def patch_attr(base: object, attr: str, value: Any):
1055+
"""
1056+
Patch the value of an object attribute. Original value is restored upon exit
1057+
1058+
:param base: object which has the attribute to patch
1059+
:param attr: name of the the attribute to patch
1060+
:param value: used to replace original value
1061+
1062+
Usage:
1063+
>>> from types import SimpleNamespace
1064+
>>> obj = SimpleNamespace()
1065+
>>> with patch_attr(obj, "attribute", "value"):
1066+
... assert obj.attribute == "value"
1067+
>>> assert not hasattr(obj, "attribute")
1068+
"""
1069+
_sentinel = object()
1070+
original_value = getattr(base, attr, _sentinel)
1071+
1072+
setattr(base, attr, value)
10561073
try:
10571074
yield
10581075
finally:
1059-
setattr(base, attr, value)
1076+
if original_value is not _sentinel:
1077+
setattr(base, attr, original_value)
1078+
else:
1079+
delattr(base, attr)

tests/llmcompressor/utils/test_helpers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
flatten_iterable,
1212
getattr_chain,
1313
interpolate,
14+
patch_attr,
1415
validate_str_iterable,
1516
)
1617

1718

19+
@pytest.mark.unit
1820
@pytest.mark.parametrize(
1921
"test_list,output",
2022
[
@@ -29,6 +31,7 @@ def test_flatten_iterable(test_list, output):
2931
assert flattened == output
3032

3133

34+
@pytest.mark.unit
3235
@pytest.mark.parametrize(
3336
"test_bool,output",
3437
[
@@ -53,6 +56,7 @@ def test_convert_to_bool(test_bool, output):
5356
assert converted == output
5457

5558

59+
@pytest.mark.unit
5660
@pytest.mark.parametrize(
5761
"test_list,output",
5862
[
@@ -68,11 +72,13 @@ def test_validate_str_iterable(test_list, output):
6872
assert validated == output
6973

7074

75+
@pytest.mark.unit
7176
def test_validate_str_iterable_negative():
7277
with pytest.raises(ValueError):
7378
validate_str_iterable("will fail", "")
7479

7580

81+
@pytest.mark.unit
7682
@pytest.mark.parametrize(
7783
"x_cur,x0,x1,y0,y1,inter_func,out",
7884
[
@@ -92,6 +98,7 @@ def test_interpolate(x_cur, x0, x1, y0, y1, inter_func, out):
9298
assert abs(out - interpolated) < 0.01
9399

94100

101+
@pytest.mark.unit
95102
def test_getattr_chain():
96103
base = SimpleNamespace()
97104
base.a = None
@@ -123,13 +130,15 @@ def test_getattr_chain():
123130
getattr_chain(base, "b.d.dne")
124131

125132

133+
@pytest.mark.unit
126134
def test_DisableQuantization():
127135
model = torch.nn.Linear(1, 1)
128136
with DisableQuantization(model):
129137
assert not model.quantization_enabled
130138
assert model.quantization_enabled
131139

132140

141+
@pytest.mark.unit
133142
def test_calibration_forward_context():
134143
model = torch.nn.Linear(1, 1)
135144
model.config = SimpleNamespace()
@@ -143,3 +152,21 @@ def test_calibration_forward_context():
143152
assert torch.is_grad_enabled()
144153
assert model.config.use_cache
145154
assert model.training
155+
156+
157+
@pytest.mark.unit
158+
def test_patch_attr():
159+
# patch, original value
160+
obj = SimpleNamespace()
161+
obj.attribute = "original"
162+
with patch_attr(obj, "attribute", "patched"):
163+
assert obj.attribute == "patched"
164+
obj.attribute = "modified"
165+
assert obj.attribute == "original"
166+
167+
# patch, no original attribute
168+
obj = SimpleNamespace()
169+
with patch_attr(obj, "attribute", "patched"):
170+
assert obj.attribute == "patched"
171+
obj.attribute = "modified"
172+
assert not hasattr(obj, "attribute")

0 commit comments

Comments
 (0)