Skip to content

Commit 48e9ffc

Browse files
ppanchalia authored and pytorchmergebot committed
Unify on dynamo_compile as the overall wait counter (pytorch#150293)
Summary: dynamo_compile has, for the most part, been accounting for compile time except autotuning. all_compilation_types had earlier been injected on fx_codegen_and_compile, which was incorrect. Add autotuning to dynamo and deprecate the all_compilation_types counter. Differential Revision: D72145447 Pull Request resolved: pytorch#150293 Approved by: https://github.com/masnesral, https://github.com/jamesjwu
1 parent 36f2d0a commit 48e9ffc

File tree

4 files changed

+19
-21
lines changed

4 files changed

+19
-21
lines changed

torch/_dynamo/convert_frame.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,6 @@ def compile_inner(
774774
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
775775
)
776776
)
777-
stack.enter_context(
778-
_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()
779-
)
780777
stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
781778
stack.enter_context(CompileTimeInstructionCounter.record())
782779
return _compile_inner(code, one_graph, hooks, transform)
@@ -957,7 +954,9 @@ def count_args(code: CodeType) -> int:
957954
chromium_event_timed(
958955
"dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
959956
),
957+
_WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(),
960958
metrics_context,
959+
_WaitCounter("pytorch.wait_counter.dynamo_compile").guard(),
961960
):
962961
restart_reasons: set[str] = set()
963962
# This is shared across restarts

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from torch._prims_common import CUDARngStateHelper
3232
from torch._subclasses import FakeTensor
3333
from torch.fx.experimental._backward_state import BackwardState
34+
from torch.monitor import _WaitCounter
3435
from torch.multiprocessing.reductions import StorageWeakRef
3536
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
3637

@@ -2225,7 +2226,9 @@ def _backward_impl(ctx, all_args):
22252226
dynamo_compile_column_us="backward_cumulative_compile_time_us",
22262227
log_waitcounter=True,
22272228
waitcounter_name_override="entire_backward_compile",
2228-
):
2229+
), _WaitCounter(
2230+
"pytorch.wait_counter.dynamo_compile"
2231+
).guard():
22292232
CompileEventLogger.compilation_metric(is_forward=False)
22302233
# See Note: [Backward graph lazy lowering]
22312234
CompiledFunction.compiled_bw = aot_config.bw_compiler(

torch/_inductor/compile_fx.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,15 +620,6 @@ def compile_fx_inner(
620620
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
621621
)
622622
)
623-
# NB: Why is this the dynamo_compile counter? The rule here is that
624-
# if it gets an entry in the dynamo_compile table, we also want to
625-
# tick up the wait counter. We have to displeasingly manually trigger
626-
# the counter here because we may dropped into compile_fx directly
627-
# from lazy backwards compilation.
628-
stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard())
629-
stack.enter_context(
630-
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard()
631-
)
632623

633624
if torch._dynamo.callback_handler.prevent_duplicate_callbacks:
634625
stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
@@ -691,7 +682,6 @@ def _compile_fx_inner(
691682

692683
with (
693684
_WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
694-
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
695685
):
696686
use_cache = (
697687
not config.force_disable_caches

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import torch
3333
from torch._prims_common import compute_required_storage_length
34+
from torch.monitor import _WaitCounter
3435
from torch.utils._ordered_set import OrderedSet
3536

3637
from ..triton_bundler import TritonBundler
@@ -815,13 +816,18 @@ def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
815816
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
816817

817818
def benchmark_all_configs(self, *args, **kwargs):
818-
with dynamo_timed(
819-
"CachingAutotuner.benchmark_all_configs",
820-
log_pt2_compile_event=True,
821-
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
822-
dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
823-
compile_id=self.compile_id,
824-
is_backward=self.is_backward,
819+
with (
820+
dynamo_timed(
821+
"CachingAutotuner.benchmark_all_configs",
822+
log_pt2_compile_event=True,
823+
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
824+
dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
825+
compile_id=self.compile_id,
826+
is_backward=self.is_backward,
827+
log_waitcounter=True,
828+
waitcounter_name_override="triton_autotuner",
829+
),
830+
_WaitCounter("pytorch.wait_counter.dynamo_compile").guard(),
825831
):
826832
timings = {
827833
launcher: self.bench(launcher, *args, **kwargs)

0 commit comments

Comments
 (0)