@@ -326,13 +326,13 @@ def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
326
326
for key , weight in state_dict .items ():
327
327
if key .endswith (".weight" ) and key + "_scale_inv" in state_dict :
328
328
scale_inv = state_dict [key + "_scale_inv" ]
329
- # dequantized_weight = dequantize_from_fp8(
330
- # weight, scale_inv, dtype=torch.float32
331
- # )
332
- # # update the weight and remove the scale_inv tensor
333
- # state_dict[key] = dequantized_weight
329
+ dequantized_weight = dequantize_from_fp8 (
330
+ weight , scale_inv , dtype = torch .float32
331
+ )
332
+ # update the weight and remove the scale_inv tensor
333
+ state_dict [key ] = dequantized_weight
334
334
335
- state_dict [key ] = weight
335
+ # state_dict[key] = weight
336
336
scale_inv_keys .append (key + "_scale_inv" )
337
337
338
338
for key in scale_inv_keys :
@@ -452,7 +452,15 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
452
452
stacked_value = self ._concatenate_local_expert_weights (
453
453
expert_weights_by_layer , titan_abstract_key , value .device_mesh
454
454
)
455
+
455
456
if stacked_value is not None :
457
+ local_tensor = stacked_value ._local_tensor
458
+
459
+ tensor_list = local_tensor .tolist ()
460
+ # Save to JSON file
461
+ import json
462
+ with open (f'my_implementation_tensor_{ new_key } .json' , 'w' ) as f :
463
+ json .dump (tensor_list , f )
456
464
state_dict [new_key ] = stacked_value
457
465
458
466
elif "layers" in key :
0 commit comments