     DataLoadingThread,
     use_context_for_postprocs,
 )
-from torchrec.distributed.types import Awaitable
+from torchrec.distributed.types import Awaitable, ParameterSharding
 from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.pt2.utils import default_pipeline_input_transformer
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -696,6 +696,87 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self.dequeue_batch()
         return output

+    def progress_with_reshard(
+        self,
+        dataloader_iter: Iterator[In],
+        reshard_params: Dict[str, ParameterSharding],
+        sharded_module_fqn: Optional[str] = None,
+    ) -> Out:
+        """
+        Runs one training iteration and then reshards the model. Since resharding changes
+        tensor placements, pipeline overlap is temporarily undone for this iteration.
+        """
+        # Assume the model is attached and the pipeline batches are not empty:
+        # # attach the model just in case the user forgets to call it, especially when the user
+        # # pauses pipeline.progress and detaches the model for another purpose.
+        # if not self._model_attached:
+        #     self.attach(self._model)
+
+        # # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        # self.fill_pipeline(dataloader_iter)
+
+        # Assume this is not the last batch.
+        # # here is the expected stop after exhausting all batches
+        if not self.batches:
+            raise StopIteration
+
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] to be available on device; this should always be completed since
+        # the input_dist of batches[0] has been invoked in the previous iter. TODO: fact check
+        self._wait_for_batch()
+
+        # Assume _enqueue_batch_after_forward is False.
+        # if not self._enqueue_batch_after_forward:
+        #     # batch i+2: load data and copy to gpu, the dataloader iter will first exhaust here
+        #     self.enqueue_batch(dataloader_iter)
+
+        # The reshard happens only after this batch's forward/backward/optimizer step.
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        # if self._enqueue_batch_after_forward:
+        #     # batch i+2: load data and copy to gpu, the dataloader iter will first exhaust here.
+        #     # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+        #     # for PCIe bandwidth with embedding lookup from UVM/UVM_CACHING.
+        #     self.enqueue_batch(dataloader_iter)
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            self.sync_embeddings(
+                self._model,
+                self._dmp_collection_sync_interval_batches,
+                self.contexts[0],
+            )
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        # Reshard
+        self._model.reshard(  # pyre-ignore
+            sharded_module_fqn=sharded_module_fqn,
+            changed_shard_to_params=reshard_params,
+        )
+
+        # The reshard must complete before the next batch's input_dist so that it
+        # uses the new sharding plan.
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+        self.dequeue_batch()
+        return output
+
     def _create_context(self) -> TrainPipelineContext:
         context = self._context_type(index=self._next_index, version=1)
         self._next_index += 1
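For context, here is a minimal sketch of how this new entry point might be driven from a training loop. It assumes the method lands on `TrainPipelineSparseDist` (the class this diff modifies); `build_new_sharding_plan` and the `sharded_module_fqn` value are hypothetical placeholders for however the caller produces the new `Dict[str, ParameterSharding]` placements, and none of this is part of the diff itself.

```python
from typing import Dict, Iterator

from torchrec.distributed.types import ParameterSharding


def train_with_periodic_reshard(
    pipeline,  # assumed: a TrainPipelineSparseDist with this patch applied
    dataloader_iter: Iterator,
    build_new_sharding_plan,  # hypothetical: () -> Dict[str, ParameterSharding]
    reshard_every_n_steps: int = 1000,
    sharded_module_fqn: str = "sparse.ebc",  # hypothetical FQN of the sharded module
) -> None:
    """Run overlapped iterations, periodically swapping in a reshard iteration."""
    step = 0
    while True:
        try:
            if step > 0 and step % reshard_every_n_steps == 0:
                # Reshard iteration: fwd/bwd/optimizer still run under the old plan,
                # then the model is resharded before the next batch's input_dist,
                # so pipeline overlap is given up for this one step.
                reshard_params: Dict[str, ParameterSharding] = build_new_sharding_plan()
                pipeline.progress_with_reshard(
                    dataloader_iter,
                    reshard_params=reshard_params,
                    sharded_module_fqn=sharded_module_fqn,
                )
            else:
                # Normal overlapped iteration.
                pipeline.progress(dataloader_iter)
        except StopIteration:
            break
        step += 1
```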