Skip to content

Commit c56bac1

Browse files
committed
fix assemble algo
1 parent 6528e9d commit c56bac1

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

torchtitan/models/deepseek_v3/model/state_dict_adapter.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -326,13 +326,13 @@ def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
326326
for key, weight in state_dict.items():
327327
if key.endswith(".weight") and key + "_scale_inv" in state_dict:
328328
scale_inv = state_dict[key + "_scale_inv"]
329-
# dequantized_weight = dequantize_from_fp8(
330-
# weight, scale_inv, dtype=torch.float32
331-
# )
332-
# # update the weight and remove the scale_inv tensor
333-
# state_dict[key] = dequantized_weight
329+
dequantized_weight = dequantize_from_fp8(
330+
weight, scale_inv, dtype=torch.float32
331+
)
332+
# update the weight and remove the scale_inv tensor
333+
state_dict[key] = dequantized_weight
334334

335-
state_dict[key] = weight
335+
# state_dict[key] = weight
336336
scale_inv_keys.append(key + "_scale_inv")
337337

338338
for key in scale_inv_keys:
@@ -452,7 +452,15 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
452452
stacked_value = self._concatenate_local_expert_weights(
453453
expert_weights_by_layer, titan_abstract_key, value.device_mesh
454454
)
455+
455456
if stacked_value is not None:
457+
local_tensor = stacked_value._local_tensor
458+
459+
tensor_list = local_tensor.tolist()
460+
# Save to JSON file
461+
import json
462+
with open(f'my_implementation_tensor_{new_key}.json', 'w') as f:
463+
json.dump(tensor_list, f)
456464
state_dict[new_key] = stacked_value
457465

458466
elif "layers" in key:

0 commit comments

Comments
 (0)