@@ -25,26 +25,44 @@ def my_end_callback():
2525 print("Compilation complete")
2626"""
2727
28+ import enum
2829import threading
2930from collections .abc import Generator
3031from contextlib import contextmanager
3132from dataclasses import dataclass , field # noqa: F811
3233from typing import Any , Callable
3334
3435
36+ class CallbackTrigger (enum .Enum ):
37+ # most common case, dynamo attempts to trace a new frame
38+ DYNAMO = 1
39+ # backward compilation can be deferred to runtime
40+ LAZY_BACKWARD = 2
41+ # some backends autotune at runtime
42+ TRITON_AUTOTUNING = 3
43+ # cudagraphs record at runtime
44+ CUDAGRAPH_RECORDING = 4
45+
46+
47+ @dataclass
48+ class CallbackArgs :
49+ callback_trigger : CallbackTrigger
50+ compile_id : str
51+
52+
3553@dataclass
3654class CompilationCallbackHandler :
37- start_callbacks : list [Callable [[], None ]] = field (default_factory = list )
38- end_callbacks : list [Callable [[], None ]] = field (default_factory = list )
55+ start_callbacks : list [Callable [[CallbackArgs ], None ]] = field (default_factory = list )
56+ end_callbacks : list [Callable [[CallbackArgs ], None ]] = field (default_factory = list )
3957
4058 __pending_callbacks_counter : int = field (default = 0 , init = False , repr = False )
4159 __pending_callbacks_counter_lock : threading .Lock = field (
4260 default_factory = threading .Lock , init = False , repr = False
4361 )
4462
4563 def register_start_callback (
46- self , callback : Callable [[], None ]
47- ) -> Callable [[], None ]:
64+ self , callback : Callable [[CallbackArgs ], None ]
65+ ) -> Callable [[CallbackArgs ], None ]:
4866 """
4967 Register a callback function to be called when the compilation starts.
5068
@@ -54,7 +72,9 @@ def register_start_callback(
5472 self .start_callbacks .append (callback )
5573 return callback
5674
57- def register_end_callback (self , callback : Callable [[], None ]) -> Callable [[], None ]:
75+ def register_end_callback (
76+ self , callback : Callable [[CallbackArgs ], None ]
77+ ) -> Callable [[CallbackArgs ], None ]:
5878 """
5979 Register a callback function to be called when the compilation ends.
6080
@@ -64,7 +84,7 @@ def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], No
6484 self .end_callbacks .append (callback )
6585 return callback
6686
67- def remove_start_callback (self , callback : Callable [[], None ]) -> None :
87+ def remove_start_callback (self , callback : Callable [[CallbackArgs ], None ]) -> None :
6888 """
6989 Remove a registered start callback function.
7090
@@ -73,7 +93,7 @@ def remove_start_callback(self, callback: Callable[[], None]) -> None:
7393 """
7494 self .start_callbacks .remove (callback )
7595
76- def remove_end_callback (self , callback : Callable [[], None ]) -> None :
96+ def remove_end_callback (self , callback : Callable [[CallbackArgs ], None ]) -> None :
7797 """
7898 Remove a registered end callback function.
7999
@@ -82,29 +102,32 @@ def remove_end_callback(self, callback: Callable[[], None]) -> None:
82102 """
83103 self .end_callbacks .remove (callback )
84104
85- def run_start_callbacks (self ) -> None :
105+ def run_start_callbacks (self , args : CallbackArgs ) -> None :
86106 """
87107 Execute all registered start callbacks.
88108 """
89109 for callback in self .start_callbacks :
90- callback ()
110+ callback (args )
91111
92- def run_end_callbacks (self ) -> None :
112+ def run_end_callbacks (self , args : CallbackArgs ) -> None :
93113 """
94114 Execute all registered end callbacks.
95115 """
96116 for callback in self .end_callbacks :
97- callback ()
117+ callback (args )
98118
99119 @contextmanager
100- def install_callbacks (self ) -> Generator [None , Any , Any ]:
120+ def install_callbacks (
121+ self , trigger : CallbackTrigger , compile_id : str
122+ ) -> Generator [None , Any , Any ]:
101123 """
102124 Context manager to install the callbacks and run them when the context is exited.
103125 """
126+ args = CallbackArgs (trigger , compile_id )
104127 try :
105128 with self .__pending_callbacks_counter_lock :
106129 if self .__pending_callbacks_counter == 0 :
107- self .run_start_callbacks ()
130+ self .run_start_callbacks (args )
108131 self .__pending_callbacks_counter += 1
109132 yield
110133 finally :
@@ -113,7 +136,7 @@ def install_callbacks(self) -> Generator[None, Any, Any]:
113136 "Pending callbacks counter cannot become negative."
114137 )
115138 if self .__pending_callbacks_counter == 1 :
116- self .run_end_callbacks ()
139+ self .run_end_callbacks (args )
117140 self .__pending_callbacks_counter -= 1
118141
119142 def clear (self ) -> None :
@@ -122,20 +145,25 @@ def clear(self) -> None:
122145 """
123146 self .start_callbacks .clear ()
124147 self .end_callbacks .clear ()
148+ assert self .__pending_callbacks_counter == 0
125149
126150
127151callback_handler = CompilationCallbackHandler ()
128152
129153
130- def on_compile_start (callback : Callable [[], None ]) -> Callable [[], None ]:
154+ def on_compile_start (
155+ callback : Callable [[CallbackArgs ], None ],
156+ ) -> Callable [[CallbackArgs ], None ]:
131157 """
132158 Decorator to register a callback function for the start of the compilation.
133159 """
134160 callback_handler .register_start_callback (callback )
135161 return callback
136162
137163
138- def on_compile_end (callback : Callable [[], None ]) -> Callable [[], None ]:
164+ def on_compile_end (
165+ callback : Callable [[CallbackArgs ], None ],
166+ ) -> Callable [[CallbackArgs ], None ]:
139167 """
140168 Decorator to register a callback function for the end of the compilation.
141169 """
0 commit comments