|
18 | 18 | from collections import OrderedDict |
19 | 19 | from io import BytesIO |
20 | 20 | from pathlib import Path |
21 | | -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union |
| 21 | +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
22 | 22 | from urllib.parse import urlparse |
23 | 23 |
|
24 | 24 | import numpy |
@@ -1064,14 +1064,25 @@ def preserve_attr(base: object, attr: str): |
1064 | 1064 |
|
1065 | 1065 |
|
@contextlib.contextmanager
def align_modules(
    modules: Iterable[torch.nn.Module], execution_device: Optional[torch.device] = None
):
    """
    Context manager which keeps the given modules aligned for the duration of
    the context, optionally on an overridden execution device.

    Only modules for which `has_offloaded_params` returns True are touched.
    For each of them, on entry, the hook's `pre_forward` is called and its
    `offload` flag is set to False; on exit, `offload` is set back to True and
    the hook's `post_forward` is called. If `execution_device` is given, the
    hook's `execution_device` is replaced before `pre_forward` runs and the
    previous value is put back on exit.

    NOTE(review): presumably `pre_forward` onloads the parameters to the hook's
    execution device and `post_forward` offloads them again when `offload` is
    True -- confirm against the hook implementation.

    :param modules: modules to align; modules without offloaded parameters are
        ignored
    :param execution_device: optional device which temporarily replaces each
        hook's `execution_device` while the context is active
    """
    can_offload = [module for module in modules if has_offloaded_params(module)]

    # hook.execution_device values as they were BEFORE being overridden
    original_devices = {}
    # modules whose pre_forward completed and therefore need a post_forward
    aligned = []

    try:
        for module in can_offload:
            hook = module._hf_hook

            if execution_device is not None:
                # save before overwriting, otherwise the override itself would
                # be recorded as the "original"; setdefault keeps the first
                # (true) value if the same module is listed more than once
                original_devices.setdefault(module, hook.execution_device)
                hook.execution_device = execution_device

            hook.pre_forward(module)
            hook.offload = False
            aligned.append(module)

        yield

    finally:
        # runs on normal exit and when the body (or setup) raises, so hooks
        # are never left with an overridden device or offloading disabled
        for module, device in original_devices.items():
            module._hf_hook.execution_device = device

        for module in aligned:
            # re-enable offloading before post_forward so that it takes effect
            module._hf_hook.offload = True
            module._hf_hook.post_forward(module, None)
0 commit comments