|
25 | 25 | Union,
|
26 | 26 | )
|
27 | 27 |
|
28 |
| -import pandas as pd |
29 | 28 | from torch import nn
|
30 | 29 |
|
31 | 30 | from torchrec.distributed.embedding_types import EmbeddingComputeKernel
|
|
37 | 36 | InferenceStorageReservation,
|
38 | 37 | )
|
39 | 38 | from torchrec.distributed.planner.types import (
|
40 |
| - CriticalPathEstimate, |
41 | 39 | ParameterConstraints,
|
42 | 40 | Perf,
|
43 | 41 | ShardingOption,
|
@@ -321,7 +319,7 @@ def log(
|
321 | 319 | )
|
322 | 320 |
|
323 | 321 | # Max perf and HBM to help root cause imbalance
|
324 |
| - self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan) |
| 322 | + self._log_max_perf_and_max_hbm(perf, used_hbm) |
325 | 323 | self._log_storage_reservation_stats(
|
326 | 324 | storage_reservation,
|
327 | 325 | topology,
|
@@ -447,9 +445,7 @@ def _log_plan_imbalance_stats(
|
447 | 445 | f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#"
|
448 | 446 | )
|
449 | 447 |
|
450 |
| - def _log_max_perf_and_max_hbm( |
451 |
| - self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption] |
452 |
| - ) -> None: |
| 448 | + def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None: |
453 | 449 | total_perfs = [perf.total for perf in perfs]
|
454 | 450 |
|
455 | 451 | max_total_perf_text = f"Longest Critical Path (Maximum of Total Perf): {_generate_max_text(total_perfs)}"
|
@@ -484,8 +480,6 @@ def _log_max_perf_and_max_hbm(
|
484 | 480 | )
|
485 | 481 | sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms"
|
486 | 482 |
|
487 |
| - critical_path_estimate = _calculate_critical_path(best_plan) |
488 |
| - |
489 | 483 | self._stats_table.append(f"#{'' : ^{self._width-2}}#")
|
490 | 484 | self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#")
|
491 | 485 | self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#")
|
@@ -518,15 +512,6 @@ def _log_max_perf_and_max_hbm(
|
518 | 512 | self._stats_table.append(
|
519 | 513 | f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
|
520 | 514 | )
|
521 |
| - self._stats_table.append( |
522 |
| - f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#" |
523 |
| - ) |
524 |
| - self._stats_table.append( |
525 |
| - f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#" |
526 |
| - ) |
527 |
| - self._stats_table.append( |
528 |
| - f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#" |
529 |
| - ) |
530 | 515 |
|
531 | 516 | max_used_hbm = max(used_hbm)
|
532 | 517 | mean_used_hbm = statistics.mean(used_hbm)
|
@@ -1067,76 +1052,6 @@ def _reduce_int_list(input_list: List[int]) -> str:
|
1067 | 1052 | return ", ".join(reduced)
|
1068 | 1053 |
|
1069 | 1054 |
|
1070 |
| -def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate: |
1071 |
| - """ |
1072 |
| - Calculates the critical path of the sharding plan. Makes the following assumptions: |
1073 |
| -
|
1074 |
| - 1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp. |
1075 |
| - 2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination. |
1076 |
| - i. Communication operations for each shard from the same module <> sharding type group are executed sequentially. |
1077 |
| - ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group. |
1078 |
| - 3. There are additional synchronization points during computation (both fwd & bwd) at the rank level. |
1079 |
| - i. Computation operations for each shard from the same module are executed sequentially. |
1080 |
| - ii. Ranks need to synchronize before they can begin the next set of events. |
1081 |
| - """ |
1082 |
| - |
1083 |
| - perf_data = defaultdict(float) |
1084 |
| - for so in best_plan: |
1085 |
| - module = so.module |
1086 |
| - sharding_type = so.sharding_type |
1087 |
| - ranks = sorted([cast(int, shard.rank) for shard in so.shards]) |
1088 |
| - shard_perfs = [cast(Perf, shard.perf) for shard in so.shards] |
1089 |
| - perf_breakdowns = [ |
1090 |
| - { |
1091 |
| - "fwd_compute": perf.fwd_compute, |
1092 |
| - "fwd_comms": perf.fwd_comms, |
1093 |
| - "bwd_compute": perf.bwd_compute, |
1094 |
| - "bwd_comms": perf.bwd_comms, |
1095 |
| - "prefetch_compute": perf.prefetch_compute, |
1096 |
| - } |
1097 |
| - for perf in shard_perfs |
1098 |
| - ] |
1099 |
| - |
1100 |
| - for rank, perf_breakdown in zip(ranks, perf_breakdowns): |
1101 |
| - for perf_type in perf_breakdown: |
1102 |
| - perf_data[ |
1103 |
| - ( |
1104 |
| - rank, |
1105 |
| - module, |
1106 |
| - sharding_type, |
1107 |
| - perf_type.split("_")[0], # fwd or bwd |
1108 |
| - perf_type.split("_")[1], # compute or comms |
1109 |
| - ) |
1110 |
| - ] += perf_breakdown[perf_type] |
1111 |
| - perf_df = pd.DataFrame.from_dict(perf_data, orient="index", columns=["perf"]) |
1112 |
| - perf_df.index = pd.MultiIndex.from_tuples( |
1113 |
| - perf_df.index, |
1114 |
| - names=["rank", "module", "sharding_type", "direction", "perf_type"], |
1115 |
| - ) |
1116 |
| - |
1117 |
| - comms_estimate = ( |
1118 |
| - perf_df.xs("comms", level="perf_type") |
1119 |
| - .groupby(["rank", "module", "sharding_type", "direction"]) |
1120 |
| - .sum() |
1121 |
| - .groupby(["module", "sharding_type", "direction"]) |
1122 |
| - .max() |
1123 |
| - .sum() |
1124 |
| - .item() |
1125 |
| - ) |
1126 |
| - |
1127 |
| - comp_estimate = ( |
1128 |
| - perf_df.xs("compute", level="perf_type") |
1129 |
| - .groupby(["rank", "direction"]) |
1130 |
| - .sum() |
1131 |
| - .groupby(["direction"]) |
1132 |
| - .max() |
1133 |
| - .sum() |
1134 |
| - .item() |
1135 |
| - ) |
1136 |
| - |
1137 |
| - return CriticalPathEstimate(comms_estimate, comp_estimate) |
1138 |
| - |
1139 |
| - |
1140 | 1055 | class NoopEmbeddingStats(Stats):
|
1141 | 1056 | """
|
1142 | 1057 | Noop Stats for a sharding planner execution.
|
|
0 commit comments