Skip to content

Commit 6a3733a

Browse files
committed
fix bug in align_modules
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ad7e3ac commit 6a3733a

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/llmcompressor/utils/helpers.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from collections import OrderedDict
1919
from io import BytesIO
2020
from pathlib import Path
21-
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
21+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
2222
from urllib.parse import urlparse
2323

2424
import numpy
@@ -1064,14 +1064,25 @@ def preserve_attr(base: object, attr: str):
10641064

10651065

10661066
@contextlib.contextmanager
def align_modules(
    modules: Iterable[torch.nn.Module], execution_device: Optional[torch.device] = None
):
    """
    Context manager that moves offloaded modules (those with accelerate-style
    offloaded parameters) onto their execution device for the duration of the
    context, optionally overriding that execution device, and offloads them
    again on exit.

    Modules without offloaded params are left untouched.

    :param modules: modules to align; only those for which
        ``has_offloaded_params(module)`` is True are affected
    :param execution_device: optional device that temporarily overrides each
        hook's ``execution_device``; the hook's original device is restored
        on exit
    """
    original_devices = {}
    can_offload = [module for module in modules if has_offloaded_params(module)]

    for module in can_offload:
        if execution_device is not None:
            # record the hook's device BEFORE overriding it; saving after the
            # assignment would record the override itself, making the restore
            # after ``yield`` a no-op
            original_devices[module] = module._hf_hook.execution_device
            module._hf_hook.execution_device = execution_device

        # load offloaded parameters onto the execution device and disable
        # offloading so repeated forwards don't thrash weights on/off device
        module._hf_hook.pre_forward(module)
        module._hf_hook.offload = False

    yield

    for module in can_offload:
        if execution_device is not None:
            # restore the hook's original execution device
            module._hf_hook.execution_device = original_devices[module]

        # re-enable offloading before post_forward so parameters are moved
        # back off-device
        module._hf_hook.offload = True
        module._hf_hook.post_forward(module, None)

0 commit comments

Comments
 (0)