|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +DeepSeek V3 Load Planner for DCP that handles grouped expert tensors. |
| 9 | +
|
| 10 | +This planner validates that grouped expert tensors can be formed from individual experts |
| 11 | +in the checkpoint before creating read items. |
| 12 | +""" |
| 13 | + |
| 14 | +import re |
| 15 | +from typing import Any, List, Optional |
| 16 | + |
| 17 | +from torch.distributed._tensor import DTensor |
| 18 | +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, _create_read_items |
| 19 | +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE |
| 20 | +from torch.distributed.checkpoint.planner import LoadPlan |
| 21 | +from torchtitan.models.deepseek_v3.model.metadata import ( |
| 22 | + DeepSeekV3Metadata, |
| 23 | +) |
| 24 | + |
class DeepSeekV3LoadPlanner(DefaultLoadPlanner):
    """Load planner for DeepSeek V3 that handles grouped expert tensor validation.

    During setup the planner scans the checkpoint's storage metadata for
    per-expert weight tensors and records which grouped expert FQNs
    (``layers.<layer>.moe.experts.<w1|w2|w3>``) are backed by at least one
    of them. ``create_local_plan`` then refuses to plan a load for a grouped
    expert tensor that is not in that set.
    """

    # Matches per-expert *weight* tensors only; the trailing ``$`` rejects
    # suffixed names such as ``...weight_scale_inv``.
    _EXPERT_WEIGHT_RE = re.compile(
        r'model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(\w+)\.weight$'
    )
    # Checkpoint-side projection name -> model-side grouped weight name.
    _HF_TO_TT_WEIGHT = {'gate_proj': 'w1', 'down_proj': 'w2', 'up_proj': 'w3'}

    def __init__(self):
        """Initialize the DeepSeek V3 load planner."""
        super().__init__()
        # Populated by set_up_planner(); empty until then.
        self.valid_grouped_experts = set()

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional["DeepSeekV3Metadata"] = None,
        is_coordinator: bool = False,
    ) -> None:
        """Set up the planner and cache the valid grouped expert FQNs.

        Args:
            state_dict: The model state dict to load into.
            metadata: Checkpoint metadata; its ``sd_metadata`` is handed to
                the base planner and its ``io_metadata`` is scanned for
                per-expert tensors.
            is_coordinator: Forwarded unchanged to the base planner.

        Raises:
            ValueError: If ``metadata`` is ``None``. The parameter has a
                ``None`` default only to mirror the base signature; this
                planner cannot work without it.
        """
        if metadata is None:
            raise ValueError(
                "DeepSeekV3LoadPlanner.set_up_planner requires checkpoint metadata, got None"
            )

        super().set_up_planner(state_dict, metadata.sd_metadata, is_coordinator)
        self.metadata = metadata.sd_metadata
        self.io_metadata = metadata.io_metadata
        # Build cache of valid grouped expert FQNs once during setup
        self.valid_grouped_experts = self._build_valid_grouped_experts()

    def _build_valid_grouped_experts(self) -> set:
        """Build cache of valid grouped expert FQNs from checkpoint metadata.

        A grouped FQN is considered valid as soon as the checkpoint holds at
        least one individual expert weight for that (layer, weight type)
        pair. Completeness of the expert set is not checked here.

        Returns:
            Set of ``layers.<layer>.moe.experts.<w1|w2|w3>`` strings; empty
            when the checkpoint has no matching per-expert weight entries.
        """
        valid_fqns = set()

        for idx in self.io_metadata.storage_data:
            match = self._EXPERT_WEIGHT_RE.match(idx.fqn)
            if match is None:
                continue

            layer_idx, _expert_idx, hf_weight_type = match.groups()
            tt_weight_type = self._HF_TO_TT_WEIGHT.get(hf_weight_type)
            # Projection names outside the mapping are ignored.
            if tt_weight_type is not None:
                valid_fqns.add(f"layers.{layer_idx}.moe.experts.{tt_weight_type}")

        return valid_fqns

    def create_local_plan(self) -> LoadPlan:
        """Create a local load plan starting from the model's state_dict.

        Returns:
            A ``LoadPlan`` holding the read items for every state_dict entry.

        Raises:
            RuntimeError: If a grouped expert tensor is not in
                ``valid_grouped_experts``, or if a key is missing from the
                checkpoint metadata.
        """
        requests = []

        for fqn, obj in self.state_dict.items():
            if self._is_grouped_expert_tensor(fqn) and fqn not in self.valid_grouped_experts:
                raise RuntimeError(f"Grouped expert tensor {fqn} cannot be loaded from checkpoint")

            # Create read items for all tensors (both regular and grouped)
            self._validate_and_create_read_items(fqn, obj, self.metadata, requests)

        return LoadPlan(requests)

    def _validate_and_create_read_items(self, fqn: str, obj: Any, metadata: Any, requests: List) -> None:
        """Validate tensor and add read items to requests.

        Args:
            fqn: Key of the entry in the model state dict.
            obj: The state dict value for ``fqn``.
            metadata: Object exposing a ``state_dict_metadata`` mapping.
            requests: List that is extended in place with new read items.

        Raises:
            RuntimeError: If ``fqn`` is absent from ``state_dict_metadata``.
        """
        if fqn not in metadata.state_dict_metadata:
            raise RuntimeError(f"Missing key in checkpoint metadata: {fqn}")

        md = metadata.state_dict_metadata[fqn]

        if isinstance(obj, DTensor):
            # A DTensor may live on a submesh; only ranks that have a
            # coordinate in its mesh request reads for it.
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    def _is_grouped_expert_tensor(self, fqn: str) -> bool:
        """Check if this FQN represents a grouped expert tensor.

        The test is a plain substring match on ``moe.experts``, so names
        such as ``moe.shared_experts`` do not match.
        """
        return 'moe.experts' in fqn
| 123 | + |
0 commit comments