Commit 28a2ab6

jamesjwu authored and pytorchmergebot committed
Clear CompiledTritonKernel cache after each inductor compile (pytorch#146925)
Fix a bug introduced by D69123174: because Triton kernels are now returned directly by the worker, each future created for a Triton kernel should only be used once per compile. Otherwise, a long-running process that does something like:

```python
compiled_1 = torch.compile(fn1, mode="max-autotune", fullgraph=True)
out = compiled_1(inputs)  # run compiled_1
compiled_2 = torch.compile(fn2, mode="max-autotune", fullgraph=True)
```

where fn1 and fn2 are very similar (i.e. they would generate the same Triton kernel source code), would end up using the launcher from the first autotuning run, setting the launcher to None after running, and then reusing the same future/kernel without regenerating the launcher. Found this bug while testing internal inference models.

This does not remove @eellison's caching support for prologue benchmarking, because that happens within the same compile: pytorch#143408

Differential Revision: [D69476856](https://our.internmc.facebook.com/intern/diff/D69476856/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D69476856/)!

Pull Request resolved: pytorch#146925
Approved by: https://github.com/laithsakka, https://github.com/jansel
ghstack dependencies: pytorch#146417
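To make the failure mode concrete, here is a minimal sketch of a process-wide cache of single-use futures. The names (`ToyKernel`, `compile_kernel`, `cache`) are hypothetical stand-ins, not the actual inductor classes; the point is only that a second "compile" hitting the same cache key gets back a kernel whose launcher was already consumed:

```python
from concurrent.futures import Future


class ToyKernel:
    """Stand-in for a compiled Triton kernel whose launcher is single-use."""

    def __init__(self) -> None:
        self.launcher = lambda: "ran"

    def run(self) -> str:
        out = self.launcher()  # first call uses the launcher...
        self.launcher = None   # ...which is dropped after run/autotune
        return out


cache: dict[str, Future] = {}  # process-wide, keyed by kernel source code


def compile_kernel(source: str) -> ToyKernel:
    if source not in cache:  # without clearing, the entry outlives the compile
        fut: Future = Future()
        fut.set_result(ToyKernel())
        cache[source] = fut
    return cache[source].result()


k1 = compile_kernel("same_source")
print(k1.run())  # "ran"

k2 = compile_kernel("same_source")  # cache hit: same future, same kernel object
try:
    k2.run()
except TypeError:
    print("stale kernel: launcher already consumed by the first compile")
```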
1 parent 0acbf80 commit 28a2ab6

File tree: 1 file changed (+4, −1)

torch/_inductor/compile_fx.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -821,11 +821,14 @@ def _compile_fx_inner(
             },
             payload_fn=lambda: json.dumps(cache_info),
         )
-
     compiled_graph.post_compile(example_inputs, cudagraphs, constants)
 
     log.debug("FX codegen and compilation took %.3fs", time.time() - start)
 
+    # Clear Compiled Triton Kernels per inductor compile, as the future objects
+    # may not be valid for use after they are run/autotuned
+    torch._inductor.async_compile.CompiledTritonKernels.cache_clear()
+
     _step_logger()(
         logging.INFO,
         "torchinductor done compiling "
```

0 commit comments
