Skip to content

Commit b6fb135

Browse files
janselpytorchmergebot
authored andcommitted
[inductor] Simplify remove_kernel_local_buffers (pytorch#139452)
I plan to reuse `can_buffer_be_removed_through_fusion` in some heuristics. Pull Request resolved: pytorch#139452 Approved by: https://github.com/shunting314 ghstack dependencies: pytorch#139364, pytorch#139365, pytorch#139370
1 parent 3d633f1 commit b6fb135

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

torch/_inductor/codegen/common.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,27 +2528,16 @@ def remove_kernel_local_buffers(self) -> None:
25282528
for buf in self.store_buffer_names
25292529
if buf in scheduler.name_to_buf
25302530
)
2531-
names_to_remove = []
2532-
for out_buf in self.store_buffer_names:
2533-
if out_buf not in scheduler.name_to_buf:
2534-
# Aux buffers created during kernel codegen
2535-
names_to_remove.append(out_buf)
2536-
continue
2537-
users = scheduler.name_to_buf[out_buf].users
2538-
assert users is not None
2539-
users = OrderedSet(user.get_name() for user in users if not user.is_weak)
2540-
if users.issubset(fused_node_names):
2541-
names_to_remove.append(out_buf)
2542-
2543-
def remove_filter(n: str) -> bool:
2544-
return (
2545-
n not in self.must_keep_buffers
2546-
and n not in self.args.input_buffers
2547-
and n not in scheduler.mutation_renames
2548-
and n not in scheduler.mutation_real_name
2549-
)
2550-
2551-
names_to_remove = [*filter(remove_filter, names_to_remove)]
2531+
names_to_remove: OrderedSet[str] = OrderedSet()
2532+
for name in self.store_buffer_names:
2533+
if (
2534+
name not in self.must_keep_buffers
2535+
and name not in self.args.input_buffers
2536+
and scheduler.can_buffer_be_removed_through_fusion(
2537+
name, fused_node_names
2538+
)
2539+
):
2540+
names_to_remove.add(name)
25522541

25532542
for name in names_to_remove:
25542543
if name in self.args.inplace_buffers:

torch/_inductor/scheduler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3399,6 +3399,19 @@ def get_order(n: torch.fx.Node) -> int:
33993399
_, last = max(origins, key=operator.itemgetter(0))
34003400
V.graph.wrapper_code.enter_context(last)
34013401

3402+
def can_buffer_be_removed_through_fusion(
3403+
self, name: str, fused_node_names: OrderedSet[str]
3404+
) -> bool:
3405+
try:
3406+
users = self.name_to_buf[name].users
3407+
except KeyError:
3408+
return False
3409+
return (
3410+
all(user.is_weak or user.get_name() in fused_node_names for user in users)
3411+
and name not in self.mutation_renames
3412+
and name not in self.mutation_real_name
3413+
)
3414+
34023415
def codegen(self) -> None:
34033416
with dynamo_timed("Scheduler.codegen"):
34043417
return self._codegen()

0 commit comments

Comments
 (0)