@@ -140,17 +140,220 @@ def add_fn(x, y):
 print(f"Vector addition of\n X:\t{x} \n Y:\t{y} \n is equal to\n {out}")
 
 ######################################################################
-# Composibility and Limitations
+# Composability
+# -------------------------------------------------------------------
+#
+# User-defined Triton kernels do not automatically work with all PyTorch
+# subsystems, such as in the following use cases:
+# - Adding a CPU fallback
+# - Adding a ``FlopCounter`` formula
+# - Composing with Tensor Subclasses
+#
+# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
+#
+# ``triton_op`` is a structured way of defining a custom operator that is backed by one
+# or more Triton kernels. Like regular custom operators (``torch.library.custom_op``),
+# you are able to specify its interactions with PyTorch subsystems via ``torch.library``.
+# However, unlike ``torch.library.custom_op``, which creates callables that are opaque to
+# ``torch.compile``, ``torch.compile`` traces into a ``triton_op`` to apply optimizations.
+#
+# Here is a chart of which API to use when integrating Triton kernels with PyTorch.
+#
+# .. list-table::
+#   :header-rows: 1
+#
+#   * -
+#     - Triton kernel (no explicit ``torch.library`` wrapper)
+#     - ``torch.library.triton_op``
+#     - ``torch.library.custom_op``
+#   * - Supports inference
+#     - Yes
+#     - Yes
+#     - Yes
+#   * - Supports training
+#     - In the majority of cases
+#     - Yes
+#     - Yes
+#   * - Supports torch.compile
+#     - Yes
+#     - Yes
+#     - Yes
+#   * - Supports torch.compile(fullgraph=True)
+#     - In the majority of cases
+#     - In the majority of cases
+#     - In all cases
+#   * - Does torch.compile trace into the implementation?
+#     - Yes
+#     - Yes
+#     - No
+#   * - Supports AOTInductor
+#     - Yes
+#     - Yes
+#     - No
+#   * - Supports PyTorch Subsystems like FlopCounterMode, CPU Fallback, Tensor Subclasses
+#     - No
+#     - Yes
+#     - Yes
+
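+######################################################################
+# For comparison, a ``torch.library.custom_op`` wrapper around the ``sin`` kernel
+# defined in the next section might look like the following sketch (the
+# ``mylib::mysin_opaque`` name is hypothetical). ``torch.compile`` would treat
+# ``mysin_opaque`` as a single opaque call rather than tracing into its body:
+#
+# .. code-block:: python
+#
+#    from torch.library import custom_op
+#
+#    @custom_op("mylib::mysin_opaque", mutates_args=())
+#    def mysin_opaque(x: torch.Tensor) -> torch.Tensor:
+#        out = torch.empty_like(x)
+#        n_elements = x.numel()
+#        sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+#        return out
+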
+######################################################################
+# Wrapping Triton kernels with triton_op
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
+# Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.
+
+from torch.library import triton_op, wrap_triton
+
+@triton_op("mylib::mysin", mutates_args={})
+def mysin(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+@triton.jit
+def sin_kernel(
+    in_ptr0,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    output = tl.sin(x)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+# For reference, this helper launches the same kernel directly, without the
+# ``triton_op``/``wrap_triton`` wrappers:
+def sin_triton(x):
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+######################################################################
+# You can invoke the ``triton_op`` in one of the following two ways.
+
+x = torch.randn(3, device="cuda")
+y = mysin(x)
+z = torch.ops.mylib.mysin.default(x)
+
+assert torch.allclose(y, x.sin())
+assert torch.allclose(z, x.sin())
+
+######################################################################
+# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.
+
+y = torch.compile(mysin)(x)
+assert torch.allclose(y, x.sin())
+
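+# This example is simple enough to also compile with ``fullgraph=True``:
+y = torch.compile(mysin, fullgraph=True)(x)
+assert torch.allclose(y, x.sin())
+
+# For ``AOTInductor``, the op can be wrapped in a module and exported. The lines
+# below are a minimal sketch (the ``MySin`` module is illustrative, and they assume
+# a PyTorch build where ``torch._inductor.aoti_compile_and_package`` is available):
+#
+# >>> class MySin(torch.nn.Module):
+# >>>     def forward(self, x):
+# >>>         return mysin(x)
+# >>>
+# >>> ep = torch.export.export(MySin(), (x,))
+# >>> package_path = torch._inductor.aoti_compile_and_package(ep)
+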
+######################################################################
+# Adding training support
+# ^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
+# Prefer this to using ``torch.autograd.Function`` (which has various composability
+# footguns with ``torch.compile``).
+
+def backward(ctx, grad_output):
+    x, = ctx.saved_tensors
+    return grad_output * x.cos()
+
+def setup_context(ctx, inputs, output):
+    x, = inputs
+    ctx.save_for_backward(x)
+
+mysin.register_autograd(backward, setup_context=setup_context)
+
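+# As a quick sanity check of the formula (the gradient of ``sin`` is ``cos``):
+x = torch.randn(3, device="cuda", requires_grad=True)
+y = mysin(x)
+y.sum().backward()
+assert torch.allclose(x.grad, x.cos())
+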
+######################################################################
+# Note that the backward must be a composition of PyTorch-understood operators.
+# If you want the backward to call Triton kernels, then those must be wrapped in
+# ``triton_op`` as well:
+
+@triton.jit
+def cos_kernel(
+    in_ptr0,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    output = tl.cos(x)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+@triton_op("mylib::mycos", mutates_args={})
+def mycos(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+def backward(ctx, grad_output):
+    x, = ctx.saved_tensors
+    return grad_output * mycos(x)
+
+def setup_context(ctx, inputs, output):
+    x, = inputs
+    ctx.save_for_backward(x)
+
+mysin.register_autograd(backward, setup_context=setup_context)
+
+######################################################################
+# Adding a CPU Fallback
+# ^^^^^^^^^^^^^^^^^^^^^
+#
+# Triton kernels don't run on CPU. Use ``register_kernel`` to add a CPU (or any
+# other device) fallback for the ``triton_op``:
+
+@mysin.register_kernel("cpu")
+def _(x):
+    return torch.sin(x)
+
+x = torch.randn(3)
+y = mysin(x)
+assert torch.allclose(y, x.sin())
+
+######################################################################
+# The fallback must be composed of PyTorch operators.
+
+######################################################################
+# Adding a FlopCounter Formula
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# To specify how many flops the Triton kernel reports under PyTorch's flop counter,
+# use ``register_flop_formula``.
+
+from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
+
+@register_flop_formula(torch.ops.mylib.mysin)
+def _(x_shape):
+    numel = 1
+    for s in x_shape:
+        numel *= s
+    return numel
+
+x = torch.randn(3, device="cuda")
+
+# NB: FlopCounterMode requires tabulate.
+#
+# >>> with FlopCounterMode() as flop_counter:
+# >>>     y = mysin(x)
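+#
+# The reported count should then match the formula registered above, for example:
+#
+# >>> assert flop_counter.get_total_flops() == x.numel()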
+
+######################################################################
+# Limitations
 # --------------------------------------------------------------------
 #
 # As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
 # includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
 # You can use these features together to build complex, high-performance models.
 #
+# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for using
+# user-defined Triton kernels with tensor subclasses and other advanced features.
+#
 # However, there are certain limitations to be aware of:
 #
-# * **Tensor Subclasses:** Currently, there is no support for
-#   tensor subclasses and other advanced features.
 # * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
 #   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
 #   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used