Skip to content

Commit 6528e9d

Browse files
committed
fix loading error
1 parent e7f8607 commit 6528e9d

File tree

2 files changed

+8
-60
lines changed

2 files changed

+8
-60
lines changed

torchtitan/components/checkpoint.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -418,16 +418,7 @@ def dcp_load(
418418
)
419419

420420
state_dict = self.sd_adapter.from_hf(hf_state_dict)
421-
422-
# [rank0]:after sd converter, placement is DeviceMesh((dp_shard_mod_ep=2, dp_shard_in_ep=2, tp=2), device: 'cuda', stride: (4, 2, 1))
423-
print(
424-
f"after sd converter, placement is {state_dict['layers.3.moe.experts.w3'].device_mesh}, type {type(state_dict['layers.3.moe.experts.w3'])}, placement {state_dict['layers.3.moe.experts.w3'].placements}"
425-
)
426-
427-
# [rank0]:after sd converter, model placement is DeviceMesh((dp_shard_mod_ep=2, ep=2, tp=2), device: 'cuda', stride: (4, 2, 1))
428-
# model_state_dict = self.states[MODEL].state_dict()
429-
# print(f"after sd converter, model placement is {model_state_dict['layers.3.moe.experts.w3'].device_mesh}")
430-
421+
431422
self.states[MODEL].load_state_dict(state_dict)
432423
else:
433424
dcp.load(state_dict, checkpoint_id=checkpoint_id)

torchtitan/models/deepseek_v3/model/state_dict_adapter.py

Lines changed: 7 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import re
9+
from threading import local
910
from typing import Any, Dict
1011

1112
import torch
@@ -276,54 +277,15 @@ def _get_local_experts_weights(
276277
local_expert_tensors[expert_key] = expert_dtensor
277278

278279
return local_expert_tensors
279-
280-
def _chunk_local_expert_weights(
281-
self,
282-
local_tensor: torch.Tensor,
283-
dtensor_placements: tuple,
284-
dtensor_shape: tuple,
285-
device_mesh: DeviceMesh,
286-
):
287-
"""
288-
Chunk the local individual experts weight, assemble back to GroupedExperts weights DTensor.
289-
290-
This method is a placeholder for future implementation of expert weight concatenation.
291-
292-
Args:
293-
local_tensor: Concatenated local individual expert weights
294-
"""
295-
296-
# Calculate the index range on dim-i to chunk
297-
for i in range(1, len(dtensor_placements)):
298-
dim_size = dtensor_shape[i]
299-
start_index, end_index = self._caculate_indices_from_placements(
300-
dim=i,
301-
dim_size=dim_size,
302-
dtensor_placements=dtensor_placements,
303-
device_mesh=device_mesh,
304-
)
305-
# No need to chunk on current dimension
306-
if start_index is None or end_index is None:
307-
continue
308-
309-
# Chunk local_tensor on dim-i
310-
local_tensor = local_tensor.narrow(i, start_index, end_index - start_index)
311-
312-
# Assemble DTensor
313-
grouped_expert_weights = DTensor.from_local(
314-
local_tensor, device_mesh, dtensor_placements, run_check=False
315-
)
316-
317-
return grouped_expert_weights
318-
280+
319281
def _concatenate_local_expert_weights(
320282
self,
321283
expert_weights_by_layer: dict[str, Any],
322284
abstract_key: str,
323285
device_mesh: DeviceMesh,
324286
) -> torch.Tensor:
325287
"""
326-
Concatenate the weights of separate experts into GroupedExperts weights.
288+
Try to concatenate the weights of separate experts into GroupedExperts weights.
327289
"""
328290
for layer in expert_weights_by_layer.keys():
329291
# If we have all the experts for this abstract_key, concatenate them
@@ -335,20 +297,15 @@ def _concatenate_local_expert_weights(
335297
if len(experts) == expected_n_experts:
336298
sorted_expert_ids = sorted(experts.keys())
337299
sorted_experts = [experts[i] for i in sorted_expert_ids]
338-
local_tensor = torch.stack(sorted_experts, dim=0)
339-
300+
local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor
301+
340302
assert (
341303
abstract_key in self.grouped_expert_weight_placements
342304
and abstract_key in self.grouped_expert_weight_shape
343305
), f"GroupedExperts weight metadata {self.grouped_expert_weight_placements} {self.grouped_expert_weight_shape} can not be None!"
344306

345-
stacked_dtensor = self._chunk_local_expert_weights(
346-
local_tensor,
347-
dtensor_placements=self.grouped_expert_weight_placements[
348-
abstract_key
349-
],
350-
dtensor_shape=self.grouped_expert_weight_shape[abstract_key],
351-
device_mesh=device_mesh,
307+
stacked_dtensor = DTensor.from_local(
308+
local_tensor, device_mesh, self.grouped_expert_weight_placements[abstract_key], run_check=False
352309
)
353310

354311
# Remove these experts from the tracking dict to free memory

0 commit comments

Comments (0)