Commit 9855733

generatedunixname89002005232357 authored and facebook-github-bot committed
Revert D72410003: Multisect successfully blamed "D72410003: [TorchRec] Update critical path definition" for one test failure
Summary:
This diff reverts D72410003.

D72410003: [TorchRec] Update critical path definition by micrain causes the following test failure:

Tests affected:
- [cogwheel:cogwheel_nex_task_and_notebook_creation_test#test_task_creation](https://www.internalfb.com/intern/test/562950141228121/)

Here's the Multisect link: https://www.internalfb.com/multisect/25924686

Here are the tasks that are relevant to this breakage:
T191381105: 100+ tests, one build rule, one sandcastle job, one CI signal unhealthy for model_understanding_iroc

The backout may land if someone accepts it. If this diff has been generated in error, you can Commandeer and Abandon it.

Reviewed By: micrain

Differential Revision: D72864493

fbshipit-source-id: 67b7a69641a36c803531296301fa1a66718fca68
1 parent 2dc7dc4 commit 9855733

File tree

2 files changed: +2, -96 lines changed


torchrec/distributed/planner/stats.py

Lines changed: 2 additions & 87 deletions
@@ -25,7 +25,6 @@
     Union,
 )
 
-import pandas as pd
 from torch import nn
 
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -37,7 +36,6 @@
     InferenceStorageReservation,
 )
 from torchrec.distributed.planner.types import (
-    CriticalPathEstimate,
     ParameterConstraints,
     Perf,
     ShardingOption,
@@ -321,7 +319,7 @@ def log(
         )
 
         # Max perf and HBM to help root cause imbalance
-        self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan)
+        self._log_max_perf_and_max_hbm(perf, used_hbm)
         self._log_storage_reservation_stats(
             storage_reservation,
             topology,
@@ -447,9 +445,7 @@ def _log_plan_imbalance_stats(
             f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#"
         )
 
-    def _log_max_perf_and_max_hbm(
-        self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption]
-    ) -> None:
+    def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
         total_perfs = [perf.total for perf in perfs]
 
         max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text(total_perfs)}"
@@ -484,8 +480,6 @@ def _log_max_perf_and_max_hbm(
         )
         sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms"
 
-        critical_path_estimate = _calculate_critical_path(best_plan)
-
         self._stats_table.append(f"#{'' : ^{self._width-2}}#")
         self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#")
         self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#")
@@ -518,15 +512,6 @@ def _log_max_perf_and_max_hbm(
         self._stats_table.append(
             f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
         )
-        self._stats_table.append(
-            f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#"
-        )
-        self._stats_table.append(
-            f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#"
-        )
-        self._stats_table.append(
-            f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#"
-        )
 
         max_used_hbm = max(used_hbm)
         mean_used_hbm = statistics.mean(used_hbm)
@@ -1067,76 +1052,6 @@ def _reduce_int_list(input_list: List[int]) -> str:
     return ", ".join(reduced)
 
 
-def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate:
-    """
-    Calculates the critical path of the sharding plan. Makes the following assumptions:
-
-    1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp.
-    2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination.
-        i. Communication operations for each shard from the same module <> sharding type group are executed sequentially.
-        ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group.
-    3. There are additional synchronization points during computation (both fwd & bwd) at the rank level.
-        i. Computation operations for each shard from the same module are executed sequentially.
-        ii. Ranks need to synchronize before they can begin the next set of events.
-    """
-
-    perf_data = defaultdict(float)
-    for so in best_plan:
-        module = so.module
-        sharding_type = so.sharding_type
-        ranks = sorted([cast(int, shard.rank) for shard in so.shards])
-        shard_perfs = [cast(Perf, shard.perf) for shard in so.shards]
-        perf_breakdowns = [
-            {
-                "fwd_compute": perf.fwd_compute,
-                "fwd_comms": perf.fwd_comms,
-                "bwd_compute": perf.bwd_compute,
-                "bwd_comms": perf.bwd_comms,
-                "prefetch_compute": perf.prefetch_compute,
-            }
-            for perf in shard_perfs
-        ]
-
-        for rank, perf_breakdown in zip(ranks, perf_breakdowns):
-            for perf_type in perf_breakdown:
-                perf_data[
-                    (
-                        rank,
-                        module,
-                        sharding_type,
-                        perf_type.split("_")[0],  # fwd or bwd
-                        perf_type.split("_")[1],  # compute or comms
-                    )
-                ] += perf_breakdown[perf_type]
-    perf_df = pd.DataFrame.from_dict(perf_data, orient="index", columns=["perf"])
-    perf_df.index = pd.MultiIndex.from_tuples(
-        perf_df.index,
-        names=["rank", "module", "sharding_type", "direction", "perf_type"],
-    )
-
-    comms_estimate = (
-        perf_df.xs("comms", level="perf_type")
-        .groupby(["rank", "module", "sharding_type", "direction"])
-        .sum()
-        .groupby(["module", "sharding_type", "direction"])
-        .max()
-        .sum()
-        .item()
-    )
-
-    comp_estimate = (
-        perf_df.xs("compute", level="perf_type")
-        .groupby(["rank", "direction"])
-        .sum()
-        .groupby(["direction"])
-        .max()
-        .sum()
-        .item()
-    )
-
-    return CriticalPathEstimate(comms_estimate, comp_estimate)
-
-
 class NoopEmbeddingStats(Stats):
     """
     Noop Stats for a sharding planner execution.
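
For context, the _calculate_critical_path helper removed above builds a pandas frame keyed by (rank, module, sharding_type, direction, perf_type) and then takes maxima across ranks: comms are summed per module <> sharding-type group on each rank and each group is gated by its slowest rank, while compute is summed per rank and each direction is gated by its slowest rank. Below is a minimal, self-contained sketch of that aggregation on toy data; the dict-based toy_shards stand in for TorchRec's ShardingOption/Perf objects, prefetch is omitted, and all names and values are purely illustrative.

# Minimal sketch of the reverted critical-path aggregation, using plain dicts
# instead of TorchRec's ShardingOption/Perf types (illustrative only).
from collections import defaultdict

import pandas as pd

# Toy per-shard perf breakdowns (ms), one dict per shard.
toy_shards = [
    {"rank": 0, "module": "ebc_a", "sharding_type": "table_wise",
     "fwd_compute": 1.0, "fwd_comms": 0.5, "bwd_compute": 2.0, "bwd_comms": 0.7},
    {"rank": 1, "module": "ebc_a", "sharding_type": "table_wise",
     "fwd_compute": 1.2, "fwd_comms": 0.4, "bwd_compute": 1.8, "bwd_comms": 0.9},
    {"rank": 0, "module": "ebc_b", "sharding_type": "row_wise",
     "fwd_compute": 0.6, "fwd_comms": 0.3, "bwd_compute": 1.1, "bwd_comms": 0.2},
]

# Accumulate perf per (rank, module, sharding_type, direction, perf_type).
perf_data = defaultdict(float)
for shard in toy_shards:
    for perf_type in ("fwd_compute", "fwd_comms", "bwd_compute", "bwd_comms"):
        direction, kind = perf_type.split("_")  # fwd/bwd, compute/comms
        key = (shard["rank"], shard["module"], shard["sharding_type"], direction, kind)
        perf_data[key] += shard[perf_type]

perf_df = pd.DataFrame.from_dict(perf_data, orient="index", columns=["perf"])
perf_df.index = pd.MultiIndex.from_tuples(
    perf_df.index, names=["rank", "module", "sharding_type", "direction", "perf_type"]
)

# Comms: shards in the same (module, sharding_type, direction) group run sequentially
# on a rank, the slowest rank gates each group, and the groups add up along the path.
comms_estimate = (
    perf_df.xs("comms", level="perf_type")
    .groupby(["rank", "module", "sharding_type", "direction"]).sum()
    .groupby(["module", "sharding_type", "direction"]).max()
    .sum()
    .item()
)

# Compute: shards on a rank run sequentially per direction; the slowest rank gates
# the forward and the backward pass, which then add up.
comp_estimate = (
    perf_df.xs("compute", level="perf_type")
    .groupby(["rank", "direction"]).sum()
    .groupby(["direction"]).max()
    .sum()
    .item()
)

print(comms_estimate, comp_estimate)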

torchrec/distributed/planner/types.py

Lines changed: 0 additions & 9 deletions
@@ -843,12 +843,3 @@ def log(
         See class description
         """
         ...
-
-
-@dataclass
-class CriticalPathEstimate:
-    comms_estimate: float
-    comp_estimate: float
-
-    def total(self) -> float:
-        return self.comms_estimate + self.comp_estimate
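
The companion deletion shown above removes the CriticalPathEstimate container that stats.py formatted into its table. As a quick standalone illustration of how the two estimates combine via total(), the dataclass is copied below for the example only and the numbers are arbitrary.

# Standalone copy of the removed dataclass, for illustration only.
from dataclasses import dataclass


@dataclass
class CriticalPathEstimate:
    comms_estimate: float
    comp_estimate: float

    def total(self) -> float:
        return self.comms_estimate + self.comp_estimate


# Arbitrary example values (ms), rounded to 3 decimals as the stats table does.
estimate = CriticalPathEstimate(comms_estimate=1.9, comp_estimate=4.7)
print(round(estimate.total(), 3))  # 6.6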
