Skip to content

Commit 28796f7

Browse files
xmfanpytorchmergebot
authored and committed
Redo D75092426: [internal] Expose additional metadata to compilation callbacks (pytorch#155063)
Originally pytorch#153596 --------------- Summary: via reverting D75708685 gate the ROCm failure Test Plan: Unit tests in OSS, sandcastle Rollback Plan: Differential Revision: D75894349 Pull Request resolved: pytorch#155063 Approved by: https://github.com/masnesral
1 parent 72453a6 commit 28796f7

File tree

10 files changed

+210
-71
lines changed

10 files changed

+210
-71
lines changed

test/dynamo/test_callback.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
# Owner(s): ["module: dynamo"]
22

3+
import unittest
34
from unittest.mock import Mock
45

5-
from torch._dynamo.callback import callback_handler
6+
import torch
7+
from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger
68
from torch._dynamo.test_case import run_tests, TestCase
9+
from torch._guards import CompileId
10+
from torch.testing._internal.common_utils import TEST_WITH_ROCM
11+
from torch.testing._internal.inductor_utils import HAS_CUDA
712

813

914
class CallbackTests(TestCase):
@@ -15,16 +20,22 @@ def setUp(self) -> None:
1520
callback_handler.register_end_callback(self._on_compile_end)
1621

1722
def tearDown(self) -> None:
18-
return super().tearDown()
1923
callback_handler.clear()
24+
return super().tearDown()
2025

2126
def test_callbacks_with_duplicate_prevention(self) -> None:
22-
with callback_handler.install_callbacks(), callback_handler.install_callbacks():
27+
trigger = CallbackTrigger.DYNAMO
28+
compile_id = CompileId(0, 0)
29+
with callback_handler.install_callbacks(
30+
trigger, compile_id
31+
), callback_handler.install_callbacks(trigger, compile_id):
2332
self._on_compile_start.assert_called_once()
2433
self._on_compile_end.assert_called_once()
2534

2635
def test_counter(self) -> None:
27-
with callback_handler.install_callbacks():
36+
trigger = CallbackTrigger.DYNAMO
37+
compile_id = CompileId(0, 0)
38+
with callback_handler.install_callbacks(trigger, compile_id):
2839
self.assertEqual(
2940
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
3041
1,
@@ -35,18 +46,95 @@ def test_counter(self) -> None:
3546

3647
def test_counter_assertion(self) -> None:
3748
callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1
49+
with self.assertRaisesRegex(
50+
AssertionError, "Pending callbacks counter cannot become negative."
51+
):
52+
trigger = CallbackTrigger.DYNAMO
53+
compile_id = CompileId(0, 0)
54+
with callback_handler.install_callbacks(trigger, str(compile_id)):
55+
pass
56+
self.assertEqual(
57+
callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
58+
)
59+
60+
@unittest.skipIf(
61+
TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
62+
)
63+
@unittest.skipIf(not HAS_CUDA, "requires triton")
64+
@torch._inductor.config.patch(force_disable_caches=True)
65+
def test_triggers(self) -> None:
66+
torch._dynamo.reset()
67+
order = []
68+
69+
def on_start(args: CallbackArgs):
70+
nonlocal order
71+
order.append(f"start={args}")
3872

39-
with self.assertRaises(
40-
AssertionError
41-
) as e, callback_handler.install_callbacks():
42-
pass
73+
def on_end(args: CallbackArgs):
74+
nonlocal order
75+
order.append(f"end={args}")
76+
77+
torch._dynamo.callback.on_compile_start(on_start)
78+
torch._dynamo.callback.on_compile_start(on_end)
79+
80+
class TinyModel(torch.nn.Module):
81+
def __init__(self):
82+
super().__init__()
83+
self.fc1 = torch.nn.Linear(10, 10)
84+
self.relu = torch.nn.ReLU()
85+
self.fc2 = torch.nn.Linear(10, 10)
86+
87+
def forward(self, x):
88+
temp = self.fc1(x)
89+
temp = self.relu(temp)
90+
torch._dynamo.graph_break()
91+
return self.fc2(temp)
92+
93+
model = TinyModel().to("cuda")
94+
compiled_model = torch.compile(model, mode="max-autotune")
95+
x = torch.randn(10, 10, device="cuda")
96+
97+
loss = compiled_model(x).sum()
98+
loss.backward()
99+
self.assertExpectedInline(
100+
"\n".join(order),
101+
"""\
102+
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
103+
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
104+
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
105+
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
106+
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
107+
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
108+
start=CallbackArgs(callback_trigger=<CallbackTrigger.TRITON_AUTOTUNING: 3>, compile_id='1/0')
109+
end=CallbackArgs(callback_trigger=<CallbackTrigger.TRITON_AUTOTUNING: 3>, compile_id='1/0')
110+
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')
111+
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')
112+
start=CallbackArgs(callback_trigger=<CallbackTrigger.TRITON_AUTOTUNING: 3>, compile_id='0/0')
113+
end=CallbackArgs(callback_trigger=<CallbackTrigger.TRITON_AUTOTUNING: 3>, compile_id='0/0')""", # noqa: B950
114+
)
115+
order.clear()
43116

44-
self.assertIn(
45-
"Pending callbacks counter cannot become negative.",
46-
str(e.exception),
117+
compiled_model.zero_grad()
118+
loss = compiled_model(x).sum()
119+
loss.backward()
120+
self.assertExpectedInline(
121+
"\n".join(order),
122+
"""\
123+
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
124+
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
125+
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
126+
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
127+
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
128+
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
129+
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
130+
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')""", # noqa: B950
47131
)
132+
order.clear()
48133

49-
callback_handler._CompilationCallbackHandler__pending_callbacks_counter += 1
134+
compiled_model.zero_grad()
135+
loss = compiled_model(x).sum()
136+
loss.backward()
137+
self.assertEqual(len(order), 0)
50138

51139

52140
if __name__ == "__main__":

test/dynamo/test_compile.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ def test_compilation_callback(self):
8181
torch._dynamo.reset()
8282

8383
@torch._dynamo.on_compile_start
84-
def start_callback():
84+
def start_callback(_):
8585
print("Compilation started.")
8686

8787
@torch._dynamo.on_compile_end
88-
def end_callback():
88+
def end_callback(_):
8989
print("Compilation ended.")
9090

9191
mod = ToyModel()
@@ -116,13 +116,13 @@ def test_compilation_callback_with_graph_break(self):
116116
counter = 0
117117

118118
@torch._dynamo.on_compile_start
119-
def start_callback():
119+
def start_callback(_):
120120
nonlocal counter
121121
counter += 1
122122
print(f"Counter = {counter}")
123123

124124
@torch._dynamo.on_compile_end
125-
def end_callback():
125+
def end_callback(_):
126126
nonlocal counter
127127
counter += 1
128128
print(f"Counter = {counter}")

torch/_dynamo/callback.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,44 @@ def my_end_callback():
2525
print("Compilation complete")
2626
"""
2727

28+
import enum
2829
import threading
2930
from collections.abc import Generator
3031
from contextlib import contextmanager
3132
from dataclasses import dataclass, field # noqa: F811
3233
from typing import Any, Callable
3334

3435

36+
class CallbackTrigger(enum.Enum):
37+
# most common case, dynamo attempts to trace a new frame
38+
DYNAMO = 1
39+
# backward compilation can be deferred to runtime
40+
LAZY_BACKWARD = 2
41+
# some backends autotune at runtime
42+
TRITON_AUTOTUNING = 3
43+
# cudagraphs record at runtime
44+
CUDAGRAPH_RECORDING = 4
45+
46+
47+
@dataclass
48+
class CallbackArgs:
49+
callback_trigger: CallbackTrigger
50+
compile_id: str
51+
52+
3553
@dataclass
3654
class CompilationCallbackHandler:
37-
start_callbacks: list[Callable[[], None]] = field(default_factory=list)
38-
end_callbacks: list[Callable[[], None]] = field(default_factory=list)
55+
start_callbacks: list[Callable[[CallbackArgs], None]] = field(default_factory=list)
56+
end_callbacks: list[Callable[[CallbackArgs], None]] = field(default_factory=list)
3957

4058
__pending_callbacks_counter: int = field(default=0, init=False, repr=False)
4159
__pending_callbacks_counter_lock: threading.Lock = field(
4260
default_factory=threading.Lock, init=False, repr=False
4361
)
4462

4563
def register_start_callback(
46-
self, callback: Callable[[], None]
47-
) -> Callable[[], None]:
64+
self, callback: Callable[[CallbackArgs], None]
65+
) -> Callable[[CallbackArgs], None]:
4866
"""
4967
Register a callback function to be called when the compilation starts.
5068
@@ -54,7 +72,9 @@ def register_start_callback(
5472
self.start_callbacks.append(callback)
5573
return callback
5674

57-
def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], None]:
75+
def register_end_callback(
76+
self, callback: Callable[[CallbackArgs], None]
77+
) -> Callable[[CallbackArgs], None]:
5878
"""
5979
Register a callback function to be called when the compilation ends.
6080
@@ -64,7 +84,7 @@ def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], No
6484
self.end_callbacks.append(callback)
6585
return callback
6686

67-
def remove_start_callback(self, callback: Callable[[], None]) -> None:
87+
def remove_start_callback(self, callback: Callable[[CallbackArgs], None]) -> None:
6888
"""
6989
Remove a registered start callback function.
7090
@@ -73,7 +93,7 @@ def remove_start_callback(self, callback: Callable[[], None]) -> None:
7393
"""
7494
self.start_callbacks.remove(callback)
7595

76-
def remove_end_callback(self, callback: Callable[[], None]) -> None:
96+
def remove_end_callback(self, callback: Callable[[CallbackArgs], None]) -> None:
7797
"""
7898
Remove a registered end callback function.
7999
@@ -82,29 +102,32 @@ def remove_end_callback(self, callback: Callable[[], None]) -> None:
82102
"""
83103
self.end_callbacks.remove(callback)
84104

85-
def run_start_callbacks(self) -> None:
105+
def run_start_callbacks(self, args: CallbackArgs) -> None:
86106
"""
87107
Execute all registered start callbacks.
88108
"""
89109
for callback in self.start_callbacks:
90-
callback()
110+
callback(args)
91111

92-
def run_end_callbacks(self) -> None:
112+
def run_end_callbacks(self, args: CallbackArgs) -> None:
93113
"""
94114
Execute all registered end callbacks.
95115
"""
96116
for callback in self.end_callbacks:
97-
callback()
117+
callback(args)
98118

99119
@contextmanager
100-
def install_callbacks(self) -> Generator[None, Any, Any]:
120+
def install_callbacks(
121+
self, trigger: CallbackTrigger, compile_id: str
122+
) -> Generator[None, Any, Any]:
101123
"""
102124
Context manager to install the callbacks and run them when the context is exited.
103125
"""
126+
args = CallbackArgs(trigger, compile_id)
104127
try:
105128
with self.__pending_callbacks_counter_lock:
106129
if self.__pending_callbacks_counter == 0:
107-
self.run_start_callbacks()
130+
self.run_start_callbacks(args)
108131
self.__pending_callbacks_counter += 1
109132
yield
110133
finally:
@@ -113,7 +136,7 @@ def install_callbacks(self) -> Generator[None, Any, Any]:
113136
"Pending callbacks counter cannot become negative."
114137
)
115138
if self.__pending_callbacks_counter == 1:
116-
self.run_end_callbacks()
139+
self.run_end_callbacks(args)
117140
self.__pending_callbacks_counter -= 1
118141

119142
def clear(self) -> None:
@@ -122,20 +145,25 @@ def clear(self) -> None:
122145
"""
123146
self.start_callbacks.clear()
124147
self.end_callbacks.clear()
148+
assert self.__pending_callbacks_counter == 0
125149

126150

127151
callback_handler = CompilationCallbackHandler()
128152

129153

130-
def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]:
154+
def on_compile_start(
155+
callback: Callable[[CallbackArgs], None],
156+
) -> Callable[[CallbackArgs], None]:
131157
"""
132158
Decorator to register a callback function for the start of the compilation.
133159
"""
134160
callback_handler.register_start_callback(callback)
135161
return callback
136162

137163

138-
def on_compile_end(callback: Callable[[], None]) -> Callable[[], None]:
164+
def on_compile_end(
165+
callback: Callable[[CallbackArgs], None],
166+
) -> Callable[[CallbackArgs], None]:
139167
"""
140168
Decorator to register a callback function for the end of the compilation.
141169
"""

torch/_dynamo/convert_frame.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import torch
4949
import torch._logging
5050
from torch._C._dynamo.guards import GlobalStateGuard
51+
from torch._dynamo.callback import CallbackTrigger
5152
from torch._dynamo.distributed import get_compile_pg
5253
from torch._dynamo.symbolic_convert import TensorifyState
5354
from torch._guards import compile_context, CompileContext, CompileId, tracing
@@ -774,7 +775,11 @@ def compile_inner(
774775
transform: Callable[[list[Instruction], dict[str, Any]], Any],
775776
) -> ConvertFrameReturn:
776777
with contextlib.ExitStack() as stack:
777-
stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
778+
stack.enter_context(
779+
torch._dynamo.callback_handler.install_callbacks(
780+
CallbackTrigger.DYNAMO, str(CompileContext.current_compile_id())
781+
)
782+
)
778783
stack.enter_context(CompileTimeInstructionCounter.record())
779784
return _compile_inner(code, one_graph, hooks, transform)
780785

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
import torch.utils.dlpack
2222
from torch import Tensor
23+
from torch._dynamo.callback import callback_handler, CallbackTrigger
2324
from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context
2425
from torch._guards import (
2526
compile_context,
@@ -2290,6 +2291,9 @@ def _backward_impl(ctx, all_args):
22902291
dynamo_compile_column_us="backward_cumulative_compile_time_us",
22912292
log_waitcounter=True,
22922293
waitcounter_name_override="entire_backward_compile",
2294+
), callback_handler.install_callbacks(
2295+
CallbackTrigger.LAZY_BACKWARD,
2296+
str(CompileContext.current_compile_id()),
22932297
):
22942298
CompileEventLogger.compilation_metric(is_forward=False)
22952299
# See Note: [Backward graph lazy lowering]

torch/_inductor/async_compile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def reload_kernel_in_parent():
346346
else:
347347
return future.result()
348348

349+
# Cache miss
349350
if is_parallel:
350351
# We want to support changing these env vars after (and while) the
351352
# process pool is running, so pass them to the subprocess to reset.

torch/_inductor/compile_fx.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,6 @@ def compile_fx_inner(
707707
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
708708
)
709709
)
710-
stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
711710
stack.enter_context(with_fresh_cache_if_config())
712711
stack.enter_context(DebugContext())
713712
CompileEventLogger.pt2_compile(

0 commit comments

Comments
 (0)