 import tqdm
 import contextlib
-from compressed_tensors.utils import replace_module, match_named_modules
+from compressed_tensors.utils import replace_module, delete_offload_module, register_offload_module
 from transformers import PreTrainedModel

 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
@@ -47,22 +47,26 @@ def update_qwen3_moe(model, stack):

 def update_gpt_oss_moe(model: PreTrainedModel, stack):
     @contextlib.contextmanager
-    def replace_context(model, name, module):
+    def replace_context(parent, name, module):
         linear = GptOssExpertsLinear(module)
-        replace_module(model, name, linear)
         del module
+        delete_offload_module(parent, name)
+        register_offload_module(parent, name, linear)

         yield

         restored = linear.to_original()
-        replace_module(model, name, restored)
+        del linear
+        delete_offload_module(parent, name)
+        register_offload_module(parent, name, restored)

-    # TODO: need to think about duplicates
+    # TODO: handle the case where the replaced module appears more than once in the model structure
     modules = list(model.named_modules())
     for name, module in tqdm.tqdm(modules, desc="Checking modules for replacements"):
-        cls_name = module.__class__.__name__
-        if cls_name == "GptOssExperts":
-            stack.enter_context(replace_context(model, name, module))
+        children = list(module.named_children())
+        for child_name, child in children:
+            if child.__class__.__name__ == "GptOssExperts":
+                stack.enter_context(replace_context(module, child_name, child))


@@ -78,3 +82,27 @@ def moe_calibration_context(model: PreTrainedModel, stack):
     cls_name = model.__class__.__name__
     if cls_name in moe_context:
         moe_context.get(cls_name)(model, stack)
+
+
+
+# import torch
+# from accelerate.hooks import (
+#     AlignDevicesHook,
+# def replace_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+#     hook = getattr(base, name)._hf_hook
+#     delete_offload_module(base, name)
+
+#     weights_map = PrefixedDataset(
+#         hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix.remove_suffix(name + ".")}"
+#     )
+
+#     parent_hook = AlignDevicesHook(
+#         execution_device=hook.execution_device,
+#         offload=hook.offload,
+#         io_same_device=False,
+#         weights_map=weights_map,
+#         offload_buffers=offload_buffers,
+#         place_submodules=place_submodules,
+#         skip_keys=None,
+#         tied_params_map=hook.tied_params_map,
+#     )
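Not part of the diff, but for orientation: a minimal, hypothetical sketch of how a MoE calibration context like this is entered around calibration forward passes. The import path, checkpoint name, and dummy input below are assumptions for illustration, not taken from this PR.

import contextlib

import torch
from transformers import AutoModelForCausalLM

# Assumed import path for the context shown in this diff.
from llmcompressor.modeling import moe_calibration_context

# Illustrative checkpoint; any GPT-OSS model with GptOssExperts modules would do.
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")

with contextlib.ExitStack() as stack:
    # Temporarily swaps expert modules for their linearized equivalents;
    # the originals are restored when the stack unwinds.
    moe_calibration_context(model, stack)
    with torch.no_grad():
        model(input_ids=torch.tensor([[1, 2, 3]]))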