
Commit 2265c2d

ezyang authored and pytorchmergebot committed
Add pytorch.wait_counter.actual_codegen_and_compile WaitCounter (pytorch#138010)
The current pytorch.wait_counter.codegen_and_compile scopes over both cache hits and misses, so it doesn't accurately tell you whether you are actually spending time in Inductor compilation. The new counter is triggered /only/ when we are actually about to spend time in Inductor: it covers Inductor lowering and codegen as well as Triton compilation. It does NOT cover Triton compilation that occurs on a cache hit. Some more bikeshedding may be needed.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#138010
Approved by: https://github.com/markkm
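
A minimal sketch of the pattern this commit applies (not the PR code itself): the _WaitCounter guard wraps only the region that does real work, so no time is accumulated on the cache-hit path. It assumes the same "from torch.monitor import _WaitCounter" import that compile_fx.py uses; expensive_inductor_work is a hypothetical stand-in for the real fx_codegen_and_compile call.

from torch.monitor import _WaitCounter  # assumed import, matching compile_fx.py

def expensive_inductor_work() -> str:
    # Hypothetical stand-in for fx_codegen_and_compile(...); it only simulates
    # the lowering/codegen work so this sketch runs on its own.
    return "compiled_artifact"

def compile_with_counter() -> str:
    # The guard is entered only when we are genuinely about to do the work,
    # so pytorch.wait_counter.actual_codegen_and_compile never counts cache hits.
    with _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard():
        return expensive_inductor_work()

if __name__ == "__main__":
    print(compile_with_counter())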
1 parent 46132dc commit 2265c2d


torch/_inductor/compile_fx.py

Lines changed: 73 additions & 68 deletions
@@ -638,81 +638,86 @@ def codegen_and_compile(
         This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting
         compiled fx graph. The metadata is saved to FXGraphCache.
         """
-        compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
-        if isinstance(compiled_graph, str):
-            # We only return a string in aot mode, in which case we don't
-            # need to do any post-compilation steps: we just return the string,
-            # which is the filename of the compiled code.
-            return compiled_graph
-        cudagraph_info = None
-        if cudagraphs:
-            # check cudagraph disabling reasons from inductor lowering
-            if compiled_graph.disabled_cudagraphs_reason:
-                if "cuda" in compiled_graph.device_types:
-                    log_cudagraph_skip_and_bump_counter(
-                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
-                    )
+        with _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard():
+            compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
+            if isinstance(compiled_graph, str):
+                # We only return a string in aot mode, in which case we don't
+                # need to do any post-compilation steps: we just return the string,
+                # which is the filename of the compiled code.
+                return compiled_graph
+            cudagraph_info = None
+            if cudagraphs:
+                # check cudagraph disabling reasons from inductor lowering
+                if compiled_graph.disabled_cudagraphs_reason:
+                    if "cuda" in compiled_graph.device_types:
+                        log_cudagraph_skip_and_bump_counter(
+                            f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
+                        )
+                    else:
+                        counters["inductor"]["cudagraph_skips"] += 1
+                    BoxedBool.disable(cudagraphs)
                 else:
-                    counters["inductor"]["cudagraph_skips"] += 1
-                BoxedBool.disable(cudagraphs)
-            else:
-                complex_memory_overlap_inputs = any(
-                    complex_memory_overlap(t)
-                    for t in example_inputs
-                    if isinstance(t, torch.Tensor)
-                )
-
-                if not config.triton.cudagraph_support_input_mutation:
-                    # Skip supports for cudagraph-managed tensors
-                    from torch._inductor.cudagraph_utils import (
-                        check_for_mutation_ignore_cuda_graph_managed_tensor,
+                    complex_memory_overlap_inputs = any(
+                        complex_memory_overlap(t)
+                        for t in example_inputs
+                        if isinstance(t, torch.Tensor)
                     )
 
-                    has_mutation_str = (
-                        check_for_mutation_ignore_cuda_graph_managed_tensor(
-                            gm,
-                            compiled_graph,
-                            static_input_idxs,
+                    if not config.triton.cudagraph_support_input_mutation:
+                        # Skip supports for cudagraph-managed tensors
+                        from torch._inductor.cudagraph_utils import (
+                            check_for_mutation_ignore_cuda_graph_managed_tensor,
                         )
-                    )
-                    has_mutation = has_mutation_str is not None
 
-                    if has_mutation:
-                        compiled_graph.disabled_cudagraphs_reason = has_mutation_str
-                else:
-                    # Check mutation later to support cudagraph-managed tensors
-                    has_mutation = None
-
-                cudagraph_tests = [
-                    (not has_mutation, "mutated inputs"),
-                    (not complex_memory_overlap_inputs, "complex memory overlap"),
-                    (
-                        all(
-                            isinstance(t, (torch.Tensor, torch.SymInt))
-                            for t in example_inputs
+                        has_mutation_str = (
+                            check_for_mutation_ignore_cuda_graph_managed_tensor(
+                                gm,
+                                compiled_graph,
+                                static_input_idxs,
+                            )
+                        )
+                        has_mutation = has_mutation_str is not None
+
+                        if has_mutation:
+                            compiled_graph.disabled_cudagraphs_reason = has_mutation_str
+                    else:
+                        # Check mutation later to support cudagraph-managed tensors
+                        has_mutation = None
+
+                    cudagraph_tests = [
+                        (not has_mutation, "mutated inputs"),
+                        (not complex_memory_overlap_inputs, "complex memory overlap"),
+                        (
+                            all(
+                                isinstance(t, (torch.Tensor, torch.SymInt))
+                                for t in example_inputs
+                            ),
+                            "non-Tensor inputs",
                         ),
-                        "non-Tensor inputs",
-                    ),
-                ]
-                output = output_node(gm)
-                # output args are tuple of first argument
-                assert len(output.args) == 1
-                stack_traces = [
-                    (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-                    for arg in output.args[0]
-                ]
-                cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
-                placeholders = tuple(get_placeholder_info(gm.graph))
-                cudagraph_info = CudagraphCachedInfo(
-                    placeholders, stack_traces, cudagraph_fail_reasons
-                )
+                    ]
+                    output = output_node(gm)
+                    # output args are tuple of first argument
+                    assert len(output.args) == 1
+                    stack_traces = [
+                        (
+                            arg.stack_trace
+                            if isinstance(arg, torch.fx.node.Node)
+                            else None
+                        )
+                        for arg in output.args[0]
+                    ]
+                    cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
+                    placeholders = tuple(get_placeholder_info(gm.graph))
+                    cudagraph_info = CudagraphCachedInfo(
+                        placeholders, stack_traces, cudagraph_fail_reasons
+                    )
 
-        compiled_graph.cudagraph_info = cudagraph_info
-        compiled_graph.inputs_to_check = inputs_to_check
-        compiled_graph.fx_kwargs = fx_kwargs
-        # TODO: should this be part of fx_kwargs
-        compiled_graph.boxed_forward_device_index = boxed_forward_device_index
-        return compiled_graph
+            compiled_graph.cudagraph_info = cudagraph_info
+            compiled_graph.inputs_to_check = inputs_to_check
+            compiled_graph.fx_kwargs = fx_kwargs
+            # TODO: should this be part of fx_kwargs
+            compiled_graph.boxed_forward_device_index = boxed_forward_device_index
+            return compiled_graph
 
     with _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _:
         if (
