@@ -81,33 +81,17 @@ def call_observer(
81
81
base_name is "weight", then the module's weight tensor will be used
82
82
"""
83
83
with align_module_device (module ):
84
- if base_name == "weight" :
85
- value = module .weight
86
- g_idx = getattr (module , "weight_g_idx" , None )
87
- elif value is not None :
88
- g_idx = None
89
- else :
90
- raise ValueError (
91
- "Must provide a value to observe if not using weight observer"
92
- )
93
-
84
+ value = module .weight if base_name == "weight" else value
94
85
observer : Observer = getattr (module , f"{ base_name } _observer" )
95
86
96
87
if should_calculate_gparam :
97
88
global_scale = observer .get_global_scale (value )
98
89
update_offload_parameter (module , f"{ base_name } _global_scale" , global_scale )
99
- else :
100
- global_scale = getattr (module , f"{ base_name } _global_scale" , None )
101
90
102
91
if should_calculate_qparams :
103
- updated_scale , updated_zero_point = observer (
104
- value , g_idx = g_idx , global_scale = global_scale
105
- )
106
- # register or update scale & zero_point parameters (supports block shapes)
107
- scale_name = f"{ base_name } _scale"
108
- zp_name = f"{ base_name } _zero_point"
109
- update_offload_parameter (module , scale_name , updated_scale )
110
- update_offload_parameter (module , zp_name , updated_zero_point )
92
+ scale , zero_point = observer (value )
93
+ update_offload_parameter (module , f"{ base_name } _scale" , scale )
94
+ update_offload_parameter (module , f"{ base_name } _zero_point" , zero_point )
111
95
112
96
113
97
def update_weight_global_scale (module : Module ):
0 commit comments