|
4 | 4 | import itertools |
5 | 5 | import operator |
6 | 6 | import time |
| 7 | +from collections import defaultdict |
7 | 8 | from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union |
8 | 9 |
|
9 | 10 | import torch |
@@ -480,10 +481,16 @@ def is_similar(ca: torch.fx.node.Node, aot: torch.fx.node.Node): |
480 | 481 | and len(ca.all_input_nodes) == len(aot.all_input_nodes) |
481 | 482 | ) |
482 | 483 |
|
| 484 | + # number of times we saw this AOT backward graph, used to dedup reused graphs |
| 485 | + aot_id_counter: Dict[int, int] = defaultdict(int) |
483 | 486 | for nodecall_index, info in self.aot_graph_infos.items(): |
484 | 487 | ca_node_start_idx = info["ca_node_start_idx"] |
485 | 488 | aot_id = info["aot_id"] |
| 489 | + aot_id_postfix = "" |
486 | 490 | aot_graph = info["aot_gm"].graph |
| 491 | + if aot_id_counter[aot_id]: |
| 492 | + aot_id_postfix = f"_{aot_id_counter[aot_id]}" |
| 493 | + aot_id_counter[aot_id] += 1 |
487 | 494 |
|
488 | 495 | # 1. Find the first op from user code in the AOT graph |
489 | 496 | aot_it = iter(aot_graph.nodes) |
@@ -520,9 +527,11 @@ def is_similar(ca: torch.fx.node.Node, aot: torch.fx.node.Node): |
520 | 527 | # So any deviation is an error |
521 | 528 | raise StopIteration |
522 | 529 |
|
523 | | - ca_node.name = f"aot{aot_id}_{aot_node.name}" |
| 530 | + ca_node.name = f"aot{aot_id}{aot_id_postfix}_{aot_node.name}" |
524 | 531 | for i, inp in enumerate(aot_node.all_input_nodes): |
525 | | - ca_node.all_input_nodes[i].name = f"aot{aot_id}_{inp.name}" |
| 532 | + ca_node.all_input_nodes[ |
| 533 | + i |
| 534 | + ].name = f"aot{aot_id}{aot_id_postfix}_{inp.name}" |
526 | 535 |
|
527 | 536 | aot_node = next(aot_it) |
528 | 537 | ca_node = next(ca_it) |
|
0 commit comments