@@ -638,81 +638,86 @@ def codegen_and_compile(
         This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting
         compiled fx graph. The metadata is saved to FXGraphCache.
         """
-        compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
-        if isinstance(compiled_graph, str):
-            # We only return a string in aot mode, in which case we don't
-            # need to do any post-compilation steps: we just return the string,
-            # which is the filename of the compiled code.
-            return compiled_graph
-        cudagraph_info = None
-        if cudagraphs:
-            # check cudagraph disabling reasons from inductor lowering
-            if compiled_graph.disabled_cudagraphs_reason:
-                if "cuda" in compiled_graph.device_types:
-                    log_cudagraph_skip_and_bump_counter(
-                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
-                    )
+        with _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard():
+            compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
+            if isinstance(compiled_graph, str):
+                # We only return a string in aot mode, in which case we don't
+                # need to do any post-compilation steps: we just return the string,
+                # which is the filename of the compiled code.
+                return compiled_graph
+            cudagraph_info = None
+            if cudagraphs:
+                # check cudagraph disabling reasons from inductor lowering
+                if compiled_graph.disabled_cudagraphs_reason:
+                    if "cuda" in compiled_graph.device_types:
+                        log_cudagraph_skip_and_bump_counter(
+                            f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
+                        )
+                    else:
+                        counters["inductor"]["cudagraph_skips"] += 1
+                    BoxedBool.disable(cudagraphs)
                 else:
-                    counters["inductor"]["cudagraph_skips"] += 1
-                BoxedBool.disable(cudagraphs)
-            else:
-                complex_memory_overlap_inputs = any(
-                    complex_memory_overlap(t)
-                    for t in example_inputs
-                    if isinstance(t, torch.Tensor)
-                )
-
-                if not config.triton.cudagraph_support_input_mutation:
-                    # Skip supports for cudagraph-managed tensors
-                    from torch._inductor.cudagraph_utils import (
-                        check_for_mutation_ignore_cuda_graph_managed_tensor,
+                    complex_memory_overlap_inputs = any(
+                        complex_memory_overlap(t)
+                        for t in example_inputs
+                        if isinstance(t, torch.Tensor)
                     )

-                    has_mutation_str = (
-                        check_for_mutation_ignore_cuda_graph_managed_tensor(
-                            gm,
-                            compiled_graph,
-                            static_input_idxs,
+                    if not config.triton.cudagraph_support_input_mutation:
+                        # Skip supports for cudagraph-managed tensors
+                        from torch._inductor.cudagraph_utils import (
+                            check_for_mutation_ignore_cuda_graph_managed_tensor,
                         )
-                    )
-                    has_mutation = has_mutation_str is not None

-                    if has_mutation:
-                        compiled_graph.disabled_cudagraphs_reason = has_mutation_str
-                else:
-                    # Check mutation later to support cudagraph-managed tensors
-                    has_mutation = None
-
-                cudagraph_tests = [
-                    (not has_mutation, "mutated inputs"),
-                    (not complex_memory_overlap_inputs, "complex memory overlap"),
-                    (
-                        all(
-                            isinstance(t, (torch.Tensor, torch.SymInt))
-                            for t in example_inputs
+                        has_mutation_str = (
+                            check_for_mutation_ignore_cuda_graph_managed_tensor(
+                                gm,
+                                compiled_graph,
+                                static_input_idxs,
+                            )
+                        )
+                        has_mutation = has_mutation_str is not None
+
+                        if has_mutation:
+                            compiled_graph.disabled_cudagraphs_reason = has_mutation_str
+                    else:
+                        # Check mutation later to support cudagraph-managed tensors
+                        has_mutation = None
+
+                    cudagraph_tests = [
+                        (not has_mutation, "mutated inputs"),
+                        (not complex_memory_overlap_inputs, "complex memory overlap"),
+                        (
+                            all(
+                                isinstance(t, (torch.Tensor, torch.SymInt))
+                                for t in example_inputs
+                            ),
+                            "non-Tensor inputs",
                         ),
-                        "non-Tensor inputs",
-                    ),
-                ]
-                output = output_node(gm)
-                # output args are tuple of first argument
-                assert len(output.args) == 1
-                stack_traces = [
-                    (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-                    for arg in output.args[0]
-                ]
-                cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
-                placeholders = tuple(get_placeholder_info(gm.graph))
-                cudagraph_info = CudagraphCachedInfo(
-                    placeholders, stack_traces, cudagraph_fail_reasons
-                )
+                    ]
+                    output = output_node(gm)
+                    # output args are tuple of first argument
+                    assert len(output.args) == 1
+                    stack_traces = [
+                        (
+                            arg.stack_trace
+                            if isinstance(arg, torch.fx.node.Node)
+                            else None
+                        )
+                        for arg in output.args[0]
+                    ]
+                    cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
+                    placeholders = tuple(get_placeholder_info(gm.graph))
+                    cudagraph_info = CudagraphCachedInfo(
+                        placeholders, stack_traces, cudagraph_fail_reasons
+                    )

-        compiled_graph.cudagraph_info = cudagraph_info
-        compiled_graph.inputs_to_check = inputs_to_check
-        compiled_graph.fx_kwargs = fx_kwargs
-        # TODO: should this be part of fx_kwargs
-        compiled_graph.boxed_forward_device_index = boxed_forward_device_index
-        return compiled_graph
+            compiled_graph.cudagraph_info = cudagraph_info
+            compiled_graph.inputs_to_check = inputs_to_check
+            compiled_graph.fx_kwargs = fx_kwargs
+            # TODO: should this be part of fx_kwargs
+            compiled_graph.boxed_forward_device_index = boxed_forward_device_index
+            return compiled_graph

     with _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _:
         if (
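For reference, the pattern this hunk introduces is the same one already used one level up with `pytorch.wait_counter.fx_codegen_and_compile`: enter a `_WaitCounter(...).guard()` context so the time spent in the wrapped work is recorded. Below is a minimal sketch, assuming `_WaitCounter` is importable from `torch.monitor` (the diff only shows the call sites, not the import) and using a hypothetical `timed_step` stand-in for the real codegen call.

```python
# Minimal sketch of the wait-counter pattern from this hunk.
# Assumption: _WaitCounter is exposed via torch.monitor.
from torch.monitor import _WaitCounter


def timed_step() -> str:
    # Hypothetical stand-in for fx_codegen_and_compile(...) plus the
    # cudagraph/post-compilation bookkeeping done inside the guard.
    return "compiled"


with _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard():
    result = timed_step()
```

Note that the early `return compiled_graph` for the AOT string path stays inside the new `with` block, so that path is timed as well.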