
Commit 73a6a40

leslie-fang-intel authored and pytorchmergebot committed
[Inductor][CPP] Fix outer loop fusion buffer removed (pytorch#144243)
**Summary**

Fix issue: pytorch#144186. For the test case reported in the issue, we saw some nodes with the following `LoopNest`s:

- `LoopNest(loops=[LoopLevel(var=x0, size=8, offset=0, tiled_size=0, steps=1, parallel=0, simd_omp=False, simd_vec=False, collapsed=False, is_reduction=False), LoopLevel(var=x1, size=8, offset=0, tiled_size=0, steps=1, parallel=0, simd_omp=False, simd_vec=False, collapsed=False, is_reduction=True)], kernel=<torch._inductor.codegen.cpp.CppKernelProxy object at 0x7fc724426680>)`
- `LoopNest(loops=[LoopLevel(var=x0, size=8, offset=0, tiled_size=0, steps=16, parallel=0, simd_omp=False, simd_vec=True, collapsed=False, is_reduction=False), LoopLevel(var=x1, size=8, offset=0, tiled_size=0, steps=16, parallel=0, simd_omp=False, simd_vec=True, collapsed=False, is_reduction=True)], kernel=<torch._inductor.codegen.cpp.CppKernelProxy object at 0x7fc75c2cae60>)`

Although these two `LoopNest`s have the same `range` and `var`, they have different `steps` (1 and 16), so they fail to be merged with outer loops. Because the global buffers have already been removed by the time we localize the buffer, we need to restore the state of `V.graph.removed_buffers` before falling back to codegen without outer loop fusion.

**Test Plan**

```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_outer_loop_fusion_buffer_remove
```

Pull Request resolved: pytorch#144243
Approved by: https://github.com/jgong5
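The `steps` mismatch described in the summary can be illustrated with a small sketch. The `Level` dataclass and the `can_fuse_outer` helper are hypothetical stand-ins written for this illustration; they are not Inductor's `LoopLevel` or `check_outer_fusion_loop_level_attr`, and they assume only that the outer levels of two nests must agree on every compared attribute, including `steps`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Level:
    """Hypothetical stand-in for one loop level (not Inductor's LoopLevel)."""

    var: str
    size: int
    steps: int


def can_fuse_outer(a, b, depth):
    """Return True if the first `depth` levels of both nests match exactly."""
    if len(a) < depth or len(b) < depth:
        return False
    return all(x == y for x, y in zip(a[:depth], b[:depth]))


# Same var and size, but the second nest is vectorized with steps=16.
scalar_nest = [Level("x0", 8, 1), Level("x1", 8, 1)]
vector_nest = [Level("x0", 8, 16), Level("x1", 8, 16)]

same_steps = can_fuse_outer(scalar_nest, scalar_nest, depth=1)
mixed_steps = can_fuse_outer(scalar_nest, vector_nest, depth=1)
```

In this toy model, `mixed_steps` is `False` even though the variables and sizes match, which corresponds to the situation where codegen has to fall back to the path without outer loop fusion.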
1 parent 2f6f135 commit 73a6a40

3 files changed: +23 -1 lines changed

test/inductor/test_cpu_repro.py

Lines changed: 11 additions & 0 deletions
@@ -2958,6 +2958,17 @@ def fn(x):
                 2,
             )

+    def test_outer_loop_fusion_buffer_remove(self):
+        # https://github.com/pytorch/pytorch/issues/144186
+        def fn(x):
+            x = x.sum(dim=-1)
+            x = torch.softmax(x, -1)
+            return x
+
+        x = torch.randn(8, 8, 2)
+        metrics.reset()
+        self.common(fn, (x,))
+
     @config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False})
     def test_local_buffer_in_outer_loop_fusion(self):
         def fn(x):

torch/_inductor/codegen/cpp.py

Lines changed: 4 additions & 0 deletions
@@ -4793,6 +4793,10 @@ def try_share_local_buffer(local_buffer_layout, local_buffers):
                 if not node.check_outer_fusion_loop_level_attr(
                     cpp_kernel_proxy_list, node.outer_loop_fusion_depth
                 ):
+                    for removed_buffer in scope.removed_buffers:
+                        # Restore the removed buffers by this context before
+                        # fallback to codegen without using Local Buffer
+                        V.graph.removed_buffers.remove(removed_buffer)
                     return False
                 metrics.cpp_outer_loop_fused_inner_counts.append(
                     metrics.CppOuterLoopFusedCount(

torch/_inductor/codegen/cpp_utils.py

Lines changed: 8 additions & 1 deletion
@@ -366,6 +366,8 @@ def __init__(self, kernel_args: KernelArgs) -> None:
         self.global_buffers: Dict[str, ir.Buffer] = {}
         # map global buffer name to local buffer
         self.global_to_local: Dict[str, ir.Buffer] = {}
+        # record the global buffers that are removed by this LocalBufferContext
+        self.removed_buffers: OrderedSet[str] = OrderedSet()

     def __enter__(self):
         self.exit_stack.__enter__()
@@ -419,7 +421,12 @@ def add_local_buffer(
         )
         self.global_buffers[global_buffer_name] = global_buffer
         self.global_to_local[global_buffer_name] = local_buffer
-        V.graph.removed_buffers.add(global_buffer_name)
+        if global_buffer_name not in V.graph.removed_buffers:
+            # Record the global buffers that are removed by this LocalBufferContext
+            # since which may need to restore. Refer to issue:
+            # https://github.com/pytorch/pytorch/issues/144186
+            self.removed_buffers.add(global_buffer_name)
+        V.graph.removed_buffers.add(global_buffer_name)

     def localize_function(
         self,
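
The record-then-restore pattern in this change can be sketched in isolation. `ToyGraph`, `ToyLocalBufferContext`, and its `restore` method are hypothetical names introduced for this illustration; the real code uses `V.graph`, `LocalBufferContext`, and an inline loop in `cpp.py`, and stores names in an `OrderedSet` rather than the plain `set` assumed here.

```python
class ToyGraph:
    """Hypothetical stand-in for the graph-wide state (V.graph in the diff)."""

    def __init__(self):
        self.removed_buffers = set()


class ToyLocalBufferContext:
    """Hypothetical stand-in for LocalBufferContext; tracks its own removals."""

    def __init__(self, graph):
        self.graph = graph
        self.removed_buffers = set()

    def add_local_buffer(self, global_buffer_name):
        # Record ownership only if this context is the one removing the buffer.
        if global_buffer_name not in self.graph.removed_buffers:
            self.removed_buffers.add(global_buffer_name)
        self.graph.removed_buffers.add(global_buffer_name)

    def restore(self):
        # Undo only the removals made by this context (toy helper).
        for name in self.removed_buffers:
            self.graph.removed_buffers.remove(name)
        self.removed_buffers.clear()


graph = ToyGraph()
graph.removed_buffers.add("buf0")  # removed earlier, outside this context

scope = ToyLocalBufferContext(graph)
scope.add_local_buffer("buf0")  # already removed, so not owned by scope
scope.add_local_buffer("buf1")  # newly removed, so owned by scope
after_localize = set(graph.removed_buffers)

scope.restore()  # fallback path: the outer loop fusion check failed
after_restore = set(graph.removed_buffers)
```

Checking membership before recording is what keeps the restore step from un-removing a buffer that was already removed before the context touched it; in this sketch only `buf1` is restored, while `buf0` stays removed.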
