Skip to content

Commit 6d55c1a

Browse files
committed
WIP: need to implement offloaded module replacement
Signed-off-by: Kyle Sayers <[email protected]>
1 parent ad67532 commit 6d55c1a

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/llmcompressor/modeling/prepare.py

Lines changed: 36 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import tqdm
22
import contextlib
3-
from compressed_tensors.utils import replace_module, match_named_modules
3+
from compressed_tensors.utils import replace_module, delete_offload_module, register_offload_module
44
from transformers import PreTrainedModel
55

66
from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
@@ -47,22 +47,26 @@ def update_qwen3_moe(model, stack):
4747

4848
def update_gpt_oss_moe(model: PreTrainedModel, stack):
4949
@contextlib.contextmanager
50-
def replace_context(model, name, module):
50+
def replace_context(parent, name, module):
5151
linear = GptOssExpertsLinear(module)
52-
replace_module(model, name, linear)
5352
del module
53+
delete_offload_module(parent, name)
54+
register_offload_module(parent, name, linear)
5455

5556
yield
5657

5758
restored = linear.to_original()
58-
replace_module(model, name, restored)
59+
del linear
60+
delete_offload_module(parent, name)
61+
register_offload_module(parent, name, restored)
5962

60-
# TODO: need to think about duplicates
63+
# TODO: need to consider when replace module is duplicated in structure
6164
modules = list(model.named_modules())
6265
for name, module in tqdm.tqdm(modules, desc="Checking modules for replacements"):
63-
cls_name = module.__class__.__name__
64-
if cls_name == "GptOssExperts":
65-
stack.enter_context(replace_context(model, name, module))
66+
children = list(module.named_children())
67+
for child_name, child in children:
68+
if child.__class__.__name__ == "GptOssExperts":
69+
stack.enter_context(replace_context(module, child_name, child))
6670

6771

6872

@@ -78,3 +82,27 @@ def moe_calibration_context(model: PreTrainedModel, stack):
7882
cls_name = model.__class__.__name__
7983
if cls_name in moe_context:
8084
moe_context.get(cls_name)(model, stack)
85+
86+
87+
88+
# import torch
89+
# from accelerate.hooks import (
90+
# AlignDevicesHook,
91+
# def replace_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
92+
# hook = getattr(base, name)._hf_hook
93+
# delete_offload_module(base, name)
94+
95+
# weights_map = PrefixedDataset(
96+
# hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix.remove_suffix(name + ".")}"
97+
# )
98+
99+
# parent_hook = AlignDevicesHook(
100+
# execution_device=hook.execution_device,
101+
# offload=hook.offload,
102+
# io_same_device=False,
103+
# weights_map=weights_map,
104+
# offload_buffers=offload_buffers,
105+
# place_submodules=place_submodules,
106+
# skip_keys=None,
107+
# tied_params_map=hook.tied_params_map,
108+
# )

0 commit comments

Comments
 (0)