|
26 | 26 | from torch.distributed.checkpoint._consolidate_hf_safetensors import ( |
27 | 27 | consolidate_safetensors_files_on_every_rank, |
28 | 28 | ) |
| 29 | +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner |
29 | 30 | from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions |
30 | 31 | from torch.distributed.checkpoint.state_dict import ( |
31 | 32 | get_model_state_dict, |
@@ -73,8 +74,33 @@ def _get_state_dict(self) -> dict[str, Any]: |
73 | 74 | } |
74 | 75 | return state_dict |
75 | 76 |
|
| 77 | + def _is_converter_key(self, key: str) -> bool: |
| 78 | + """Check if a state dict key was added by a model converter.""" |
| 79 | + for part in self.model: |
| 80 | + fn = getattr(part, "converter_key_filter", None) |
| 81 | + if fn is not None and fn(key): |
| 82 | + return True |
| 83 | + return False |
| 84 | + |
| 85 | + def _save_converter_keys_only(self) -> bool: |
| 86 | + """Check if any model part requests saving only converter-added weights.""" |
| 87 | + return any( |
| 88 | + getattr(part, "save_converter_keys_only", False) for part in self.model |
| 89 | + ) |
| 90 | + |
| 91 | + def state_dict_to_save(self) -> dict[str, Any]: |
| 92 | + full_sd = self._get_state_dict() |
| 93 | + if self._save_converter_keys_only(): |
| 94 | + return {k: v for k, v in full_sd.items() if self._is_converter_key(k)} |
| 95 | + return full_sd |
| 96 | + |
| 97 | + def base_state_dict(self) -> dict[str, Any]: |
| 98 | + """Return state dict with only the original model keys (before converters).""" |
| 99 | + full_sd = self._get_state_dict() |
| 100 | + return {k: v for k, v in full_sd.items() if not self._is_converter_key(k)} |
| 101 | + |
76 | 102 | def state_dict(self) -> dict[str, Any]: |
77 | | - return self.cache_state_dict |
| 103 | + return self.state_dict_to_save() |
78 | 104 |
|
79 | 105 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: |
80 | 106 | func = functools.partial( |
@@ -321,6 +347,14 @@ class Config(Configurable.Config): |
321 | 347 | This will load the model only, excluding the specified keys. |
322 | 348 | """ |
323 | 349 |
|
    additional_load_paths: list[str] = field(default_factory=list)
    """
    Additional checkpoint paths to load from after the primary checkpoint.
    Useful for loading state dicts from multiple sources, e.g., base model
    weights from one checkpoint and LoRA adapter weights from another.
    Each path should contain a valid DCP checkpoint directory.
    """
| 357 | + |
324 | 358 | enable_first_step_checkpoint: bool = False |
325 | 359 | """ |
326 | 360 | Enable the checkpoint save at first step. This will save a checkpoint immediately |
@@ -445,6 +479,7 @@ def load_state_dict(state_dict): |
445 | 479 | self.sd_adapter = sd_adapter |
446 | 480 | self.export_dtype = TORCH_DTYPE_MAP[config.export_dtype] |
447 | 481 | self.exclude_from_loading = config.exclude_from_loading |
| 482 | + self.additional_load_paths = config.additional_load_paths |
448 | 483 | self.interval = config.interval |
449 | 484 | self.enable_first_step_checkpoint = config.enable_first_step_checkpoint |
450 | 485 |
|
@@ -600,41 +635,63 @@ def dcp_save( |
600 | 635 | def dcp_load( |
601 | 636 | self, |
602 | 637 | state_dict: dict[str, Any], |
603 | | - checkpoint_id: str, |
| 638 | + checkpoint_id: str | list[str], |
604 | 639 | from_hf: bool, |
605 | 640 | from_quantized: bool, |
606 | 641 | ) -> None: |
607 | | - """Load the checkpoint with dcp. |
| 642 | + """Load the checkpoint(s) with dcp. |
| 643 | +
|
608 | 644 | Args: |
609 | 645 | state_dict (dict): The state dict to load. |
610 | | - checkpoint_id (str): The checkpoint id to load. |
611 | | - from_hf (bool): Whether to load from HuggingFace checkpoint with |
612 | | - its own model definition and safetensors format. |
| 646 | + checkpoint_id (str | list[str]): The checkpoint id(s) to load. |
| 647 | + The first checkpoint is treated as the primary checkpoint. |
| 648 | + Additional checkpoints are always in DCP format. |
| 649 | + from_hf (bool): Whether to load the primary checkpoint from |
| 650 | + HuggingFace safetensors format. |
| 651 | + from_quantized (bool): Whether the HuggingFace checkpoint is quantized. |
613 | 652 | """ |
| 653 | + checkpoint_ids = ( |
| 654 | + [checkpoint_id] if isinstance(checkpoint_id, str) else checkpoint_id |
| 655 | + ) |
| 656 | + # planner = ( |
| 657 | + # DefaultLoadPlanner(allow_partial_load=True) |
| 658 | + # if len(checkpoint_ids) > 1 |
| 659 | + # else DefaultLoadPlanner() |
| 660 | + # ) |
| 661 | + planner = DefaultLoadPlanner(allow_partial_load=True) |
614 | 662 |
|
615 | | - if from_hf: |
616 | | - assert ( |
617 | | - self.sd_adapter is not None |
618 | | - ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." |
619 | | - hf_state_dict = self.sd_adapter.to_hf(state_dict) |
620 | | - hf_storage_reader = self.sd_adapter.get_hf_storage_reader( |
621 | | - checkpoint_id, from_quantized |
622 | | - ) |
623 | | - |
624 | | - dcp.load( |
625 | | - hf_state_dict, |
626 | | - storage_reader=hf_storage_reader, |
627 | | - ) |
| 663 | + for i, cid in enumerate(checkpoint_ids): |
| 664 | + is_primary = i == 0 |
628 | 665 |
|
629 | | - state_dict = self.sd_adapter.from_hf(hf_state_dict) |
630 | | - self.states[MODEL].load_state_dict(state_dict) |
631 | | - else: |
632 | | - dcp.load(state_dict, checkpoint_id=checkpoint_id) |
633 | | - |
634 | | - # TODO: Since we flatten the model states in state_dict, we need to |
635 | | - # manually call load_state_dict() for the model. Need to fix this. |
636 | | - if MODEL in self.states: |
637 | | - self.states[MODEL].load_state_dict(state_dict) |
| 666 | + if is_primary: |
| 667 | + if from_hf: |
| 668 | + # HF format: model only, training states from additional checkpoints |
| 669 | + assert ( |
| 670 | + self.sd_adapter is not None |
| 671 | + ), "Trying to load HF safetensors but sd_adapter is not provided." |
| 672 | + hf_state_dict = self.sd_adapter.to_hf( |
| 673 | + self.states[MODEL].base_state_dict() |
| 674 | + ) |
| 675 | + hf_storage_reader = self.sd_adapter.get_hf_storage_reader( |
| 676 | + cid, from_quantized |
| 677 | + ) |
| 678 | + dcp.load( |
| 679 | + hf_state_dict, |
| 680 | + storage_reader=hf_storage_reader, |
| 681 | + planner=planner, |
| 682 | + ) |
| 683 | + converted_sd = self.sd_adapter.from_hf(hf_state_dict) |
| 684 | + if MODEL in self.states: |
| 685 | + self.states[MODEL].load_state_dict(converted_sd) |
| 686 | + else: |
| 687 | + dcp.load(state_dict, checkpoint_id=cid, planner=planner) |
| 688 | + if MODEL in self.states: |
| 689 | + self.states[MODEL].load_state_dict(state_dict) |
| 690 | + else: |
| 691 | + # Additional checkpoints: always DCP format, load all available states |
| 692 | + dcp.load(state_dict, checkpoint_id=cid, planner=planner) |
| 693 | + if MODEL in self.states: |
| 694 | + self.states[MODEL].load_state_dict(state_dict) |
638 | 695 |
|
639 | 696 | @torch.no_grad() |
640 | 697 | def save(self, curr_step: int, last_step: bool = False) -> None: |
@@ -737,6 +794,12 @@ def load(self, step: int = -1) -> bool: |
737 | 794 | if not self.enable: |
738 | 795 | return False |
739 | 796 |
|
| 797 | + for path in self.additional_load_paths: |
| 798 | + if not os.path.isdir(path): |
| 799 | + raise ValueError( |
| 800 | + f"checkpoint.additional_load_paths contains invalid path: {path}" |
| 801 | + ) |
| 802 | + |
740 | 803 | model_only = False |
741 | 804 | from_hf = False |
742 | 805 | from_quantized = False |
@@ -808,7 +871,7 @@ def load(self, step: int = -1) -> bool: |
808 | 871 | states = self._states_to_load(model_only) |
809 | 872 | self.dcp_load( |
810 | 873 | states, |
811 | | - checkpoint_id=checkpoint_id, |
| 874 | + checkpoint_id=[checkpoint_id] + self.additional_load_paths, |
812 | 875 | from_hf=from_hf, |
813 | 876 | from_quantized=from_quantized, |
814 | 877 | ) |
@@ -947,7 +1010,7 @@ def _save_last_step(self, curr_step: int) -> None: |
947 | 1010 | # is not the same as the export dtype at the end of the training. |
948 | 1011 |
|
949 | 1012 | if self.last_save_model_only: |
950 | | - states = self.states[MODEL].state_dict() |
| 1013 | + states = self.states[MODEL].state_dict_to_save() |
951 | 1014 |
|
952 | 1015 | if self.export_dtype != torch.float32: |
953 | 1016 | states = {k: v.to(self.export_dtype) for k, v in states.items()} |
|
0 commit comments