Commit 142322b

aporialia authored and facebook-github-bot committed

Pipeline Integration

Differential Revision: D78191049

1 parent 01f8654 commit 142322b

3 files changed: +116 −3 lines

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 33 additions & 1 deletion

@@ -20,6 +20,7 @@
 See benchmark_pipeline_utils.py for step-by-step instructions.
 """

+import copy
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Type, Union

@@ -349,9 +350,40 @@ def _func_to_benchmark(
     pipeline: TrainPipeline,
 ) -> None:
     dataloader = iter(bench_inputs)
+    i = 0
     while True:
         try:
-            pipeline.progress(dataloader)
+            # import fbvscode
+
+            # fbvscode.set_trace()
+            if i == 3:
+                # Extract the existing sharding plan
+                existing_sharding_plan = pipeline._model.module.sparse.ebc.module_sharding_plan  # pyre-ignore
+                fqn_to_local_shards = "sparse.ebc"
+                # Modify the existing sharding plan (hard-coded to table_0)
+                sharding_param = copy.deepcopy(
+                    existing_sharding_plan["table_0"]
+                )
+                new_device = 1 if sharding_param.ranks[0] == 0 else 0
+                sharding_param.ranks = [new_device]
+                sharding_param.sharding_spec.shards[0].placement = (
+                    torch.distributed._remote_device(
+                        f"rank:{new_device}/cuda:{new_device}"
+                    )
+                )
+
+                new_sharding_plan = {}
+                new_sharding_plan["table_0"] = sharding_param
+                # Reshard
+                pipeline.progress_with_reshard(  # pyre-ignore
+                    dataloader_iter=dataloader,
+                    reshard_params=new_sharding_plan,
+                    sharded_module_fqn=fqn_to_local_shards,
+                )
+                i += 1
+            else:
+                pipeline.progress(dataloader)
+                i += 1
         except StopIteration:
             break
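The benchmark above hard-codes "table_0" and a two-rank flip. As a minimal sketch (not part of this commit), the same plan construction could be factored into a reusable helper; build_flip_plan is a hypothetical name, and it assumes the table is sharded table-wise with a single shard on a single rank, exactly as in the hard-coded branch above.

# Hypothetical helper (not from this commit): build a reshard plan that moves a
# single-shard, table-wise sharded table to the "other" of two ranks.
import copy
from typing import Dict

import torch
from torchrec.distributed.types import ParameterSharding


def build_flip_plan(
    module_sharding_plan: Dict[str, ParameterSharding],
    table_name: str,
    rank_a: int = 0,
    rank_b: int = 1,
) -> Dict[str, ParameterSharding]:
    # Deep-copy so the plan currently attached to the model is left untouched.
    sharding_param = copy.deepcopy(module_sharding_plan[table_name])
    new_rank = rank_b if sharding_param.ranks[0] == rank_a else rank_a
    sharding_param.ranks = [new_rank]
    sharding_param.sharding_spec.shards[0].placement = torch.distributed._remote_device(
        f"rank:{new_rank}/cuda:{new_rank}"
    )
    return {table_name: sharding_param}

With such a helper, the if i == 3: branch would reduce to a single call: pipeline.progress_with_reshard(dataloader_iter=dataloader, reshard_params=build_flip_plan(existing_sharding_plan, "table_0"), sharded_module_fqn="sparse.ebc").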

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion

@@ -1764,7 +1764,7 @@ def update_shards(
     # Modifies new_opt_state in place and returns it
     optimizer_state = update_optimizer_state_post_resharding(
         old_opt_state=old_optimizer_state,  # pyre-ignore
-        new_opt_state=copy.deepcopy(self._optim.state_dict()),
+        new_opt_state=self._optim.state_dict(),  # undo deep copy?
         ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
         output_tensor=local_optimizer_tensors,
         max_dim_0=max_dim_0,
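The "# undo deep copy?" note leaves this question open in the commit itself. For intuition only (not from this commit, and only an analogy, since TorchRec's KeyedOptimizer keys its state dict by FQN rather than by index): with a plain torch.optim optimizer, state_dict() already returns references to the live state tensors, so wrapping it in copy.deepcopy duplicates every state tensor before update_optimizer_state_post_resharding consumes the dict.

# Toy check with a plain torch.optim optimizer: state_dict() aliases the live state
# tensors, while deepcopy(state_dict()) materializes a second copy of each one.
import copy

import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adagrad([param], lr=0.1)
param.sum().backward()
opt.step()

live_sum = opt.state[param]["sum"]  # Adagrad's accumulated squared gradients
shallow = opt.state_dict()
deep = copy.deepcopy(opt.state_dict())

assert shallow["state"][0]["sum"] is live_sum      # same tensor object, no extra memory
assert deep["state"][0]["sum"] is not live_sum     # duplicated tensor: extra memory and copy time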

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 82 additions & 1 deletion

@@ -66,7 +66,7 @@
     DataLoadingThread,
     use_context_for_postprocs,
 )
-from torchrec.distributed.types import Awaitable
+from torchrec.distributed.types import Awaitable, ParameterSharding
 from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.pt2.utils import default_pipeline_input_transformer
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -696,6 +696,87 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self.dequeue_batch()
         return output

+    def progress_with_reshard(
+        self,
+        dataloader_iter: Iterator[In],
+        reshard_params: Dict[str, ParameterSharding],
+        sharded_module_fqn: Optional[str] = None,
+    ) -> Out:
+        """
+        Resharding changes tensor placements, so this step temporarily undoes the pipeline overlap.
+        """
+        # Assume pipeline batches are not empty:
+        # # attach the model just in case the user forgets to call it, especially when the user
+        # # pauses pipeline.progress and detaches the model for another purpose.
+        # if not self._model_attached:
+        #     self.attach(self._model)
+
+        # # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        # self.fill_pipeline(dataloader_iter)
+
+        # Assume this is not the last batch
+        # # here is the expected stop after exhausting all batches
+        if not self.batches:
+            raise StopIteration
+        # import fbvscode
+
+        # fbvscode.set_trace()
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] to be available on device; this should always be completed since
+        # the input_dist of batches[0] has been invoked in the previous iter. TODO: fact check
+        self._wait_for_batch()
+
+        # Assume _enqueue_batch_after_forward is False
+        # if not self._enqueue_batch_after_forward:
+        #     # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
+        #     self.enqueue_batch(dataloader_iter)
+
+        # But reshard after this.
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        # if self._enqueue_batch_after_forward:
+        #     # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+        #     # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+        #     # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+        #     self.enqueue_batch(dataloader_iter)
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            self.sync_embeddings(
+                self._model,
+                self._dmp_collection_sync_interval_batches,
+                self.contexts[0],
+            )
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        # Reshard
+        self._model.reshard(  # pyre-ignore
+            sharded_module_fqn=sharded_module_fqn,
+            changed_shard_to_params=reshard_params,
+        )
+
+        # Need to reshard before this.
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+        self.dequeue_batch()
+        return output
+
     def _create_context(self) -> TrainPipelineContext:
         context = self._context_type(index=self._next_index, version=1)
         self._next_index += 1
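Stripped of the queueing details, the ordering that progress_with_reshard enforces (the "But reshard after this." and "Need to reshard before this." comments above) is roughly the following. This is a schematic sketch only, not the class's actual code path: reshard_step is a hypothetical free function, and the model is assumed to return (losses, output) the way pipelined TorchRec models do.

# Schematic of a single non-overlapped "reshard step": reshard only after the
# optimizer update for batch i, and only before input_dist is (re)started for
# batch i+1, so the next batch's features are routed to the new shard placements.
import torch


def reshard_step(model, optimizer, batch, reshard_params, sharded_module_fqn="sparse.ebc"):
    # forward/backward/update for batch i run entirely under the old placements
    optimizer.zero_grad()
    losses, output = model(batch)
    torch.sum(losses).backward()
    optimizer.step()
    # move shards (and their optimizer state) only after the update is applied;
    # this is the same reshard call the pipeline makes (pyre-ignored there)
    model.reshard(
        sharded_module_fqn=sharded_module_fqn,
        changed_shard_to_params=reshard_params,
    )
    # only now would the pipeline (re)start input_dist for batch i+1
    return output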
