@@ -140,17 +140,224 @@ def add_fn(x, y):
 print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
 
 ######################################################################
-# Composibility and Limitations
+# Composability
+# -------------------------------------------------------------------
+#
+# User-defined Triton kernels do not automatically compose with all PyTorch
+# subsystems, as can be seen in the following use cases:
+#
+# - Adding a CPU fallback
+# - Adding a ``FlopCounter`` formula
+# - Composing with Tensor Subclasses
+#
+# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
+#
+# ``triton_op`` is a structured way of defining a custom operator that is backed by
+# one or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
+# you can specify its interactions with PyTorch subsystems via ``torch.library``.
+# However, unlike ``torch.library.custom_op``, which creates opaque callables with
+# respect to ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply
+# optimizations.
+#
+# Here’s a chart of which API to use when integrating Triton kernels with PyTorch.
+#
+# .. list-table::
+#    :header-rows: 1
+#
+#    * -
+#      - Triton kernel (no explicit ``torch.library`` wrapper)
+#      - ``torch.library.triton_op``
+#      - ``torch.library.custom_op``
+#    * - Supports inference
+#      - Yes
+#      - Yes
+#      - Yes
+#    * - Supports training
+#      - In the majority of cases
+#      - Yes
+#      - Yes
+#    * - Supports ``torch.compile``
+#      - Yes
+#      - Yes
+#      - Yes
+#    * - Supports ``torch.compile(fullgraph=True)``
+#      - In the majority of cases
+#      - In the majority of cases
+#      - In all cases
+#    * - Does ``torch.compile`` trace into the implementation?
+#      - Yes
+#      - Yes
+#      - No
+#    * - Supports AOTInductor
+#      - Yes
+#      - Yes
+#      - No
+#    * - Supports PyTorch subsystems, such as ``FlopCounterMode``, CPU fallback, and tensor subclasses
+#      - No
+#      - Yes
+#      - Yes
+
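+######################################################################
+# For contrast with the last column of the chart, below is a minimal sketch of
+# the same operator defined via ``torch.library.custom_op`` (shown as
+# non-executed code; the name ``mylib::mysin_opaque`` is ours, for illustration,
+# and it assumes the ``sin_kernel`` defined in the next section).
+# ``torch.compile`` does not trace into the body of a ``custom_op``, so the
+# Triton launch inside stays opaque to the compiler.
+#
+# >>> @torch.library.custom_op("mylib::mysin_opaque", mutates_args={}, device_types="cuda")
+# >>> def mysin_opaque(x: torch.Tensor) -> torch.Tensor:
+# >>>     out = torch.empty_like(x)
+# >>>     n_elements = x.numel()
+# >>>     sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+# >>>     return out
+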
+######################################################################
+# Wrapping Triton kernels with ``triton_op``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more
+# Triton kernels. Use ``torch.library.wrap_triton`` to wrap the calls to the
+# Triton kernel.
+
+from torch.library import triton_op, wrap_triton
+
+@triton_op("mylib::mysin", mutates_args={})
+def mysin(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+@triton.jit
+def sin_kernel(
+    in_ptr0,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    output = tl.sin(x)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+# For comparison, the same launch without the wrappers: this version works with
+# ``torch.compile``, but does not compose with the subsystems discussed below.
+def sin_triton(x):
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+######################################################################
+# You can invoke the ``triton_op`` in one of the following two ways.
+
+x = torch.randn(3, device="cuda")
+y = mysin(x)
+z = torch.ops.mylib.mysin.default(x)
+
+assert torch.allclose(y, x.sin())
+assert torch.allclose(z, x.sin())
+
+######################################################################
+# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.
+
+y = torch.compile(mysin)(x)
+assert torch.allclose(y, x.sin())
+
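+######################################################################
+# For ``AOTInductor``, one possible flow (shown as non-executed code; the module
+# wrapper and the ``"mysin.pt2"`` output path are ours, for illustration) is to
+# export a module that calls the ``triton_op`` and compile it ahead of time:
+#
+# >>> class MySin(torch.nn.Module):
+# >>>     def forward(self, x):
+# >>>         return mysin(x)
+# >>>
+# >>> ep = torch.export.export(MySin(), (x,))
+# >>> torch._inductor.aoti_compile_and_package(ep, package_path="mysin.pt2")
+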
+######################################################################
+# Adding training support
+# ^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
+# Prefer this to using ``torch.autograd.Function``, which has various composability
+# footguns when used with ``torch.compile``.
+
+def backward(ctx, grad_output):
+    x, = ctx.saved_tensors
+    # The derivative of sin(x) is cos(x).
+    return grad_output * x.cos()
+
+def setup_context(ctx, inputs, output):
+    x, = inputs
+    ctx.save_for_backward(x)
+
+mysin.register_autograd(backward, setup_context=setup_context)
+
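+######################################################################
+# With the formula registered, gradients flow through ``mysin``. As a quick
+# sanity check, the gradient of ``sin`` is ``cos``, so ``x.grad`` should equal
+# ``x.cos()``:
+
+x = torch.randn(3, device="cuda", requires_grad=True)
+y = mysin(x)
+y.sum().backward()
+assert torch.allclose(x.grad, x.cos())
+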
+######################################################################
+# Note that the backward must be a composition of PyTorch-understood operators.
+# If you want the backward to call Triton kernels, then those must be wrapped in
+# ``triton_op`` as well:
+
+@triton.jit
+def cos_kernel(
+    in_ptr0,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    output = tl.cos(x)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+@triton_op("mylib::mycos", mutates_args={})
+def mycos(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+def backward(ctx, grad_output):
+    x, = ctx.saved_tensors
+    # Compute cos(x) with the Triton-backed ``mycos`` op.
+    return grad_output * mycos(x)
+
+def setup_context(ctx, inputs, output):
+    x, = inputs
+    ctx.save_for_backward(x)
+
+mysin.register_autograd(backward, setup_context=setup_context)
+
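+######################################################################
+# A quick sanity check that the Triton-backed backward produces the same
+# gradients under ``torch.compile``:
+
+x = torch.randn(3, device="cuda", requires_grad=True)
+y = torch.compile(mysin)(x)
+y.sum().backward()
+assert torch.allclose(x.grad, x.cos())
+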
+######################################################################
+# Adding a CPU Fallback
+# ^^^^^^^^^^^^^^^^^^^^^
+#
+# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any
+# other device) fallback for the ``triton_op``:
+
+@mysin.register_kernel("cpu")
+def _(x):
+    return torch.sin(x)
+
+x = torch.randn(3)
+y = mysin(x)
+assert torch.allclose(y, x.sin())
+
+######################################################################
+# The fallback must be composed of PyTorch operators.
+
+######################################################################
+# Adding a FlopCounter Formula
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# To specify how many flops the Triton kernel reports under PyTorch's flop counter,
+# use ``register_flop_formula``.
+
+from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
+
+@register_flop_formula(torch.ops.mylib.mysin)
+def _(x_shape):
+    # One flop per element: multiply out the shape to get the element count.
+    numel = 1
+    for s in x_shape:
+        numel *= s
+    return numel
+
+x = torch.randn(3, device="cuda")
+
+#########################################################
+# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
+# Before running the code below, make sure you have ``tabulate`` installed, or
+# install it by running ``pip install tabulate``.
+#
+# >>> with FlopCounterMode() as flop_counter:
+# >>>     y = mysin(x)
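+#
+# ``FlopCounterMode.get_total_flops`` returns the aggregate count; with the
+# formula registered above, it should report one flop per element:
+#
+# >>> assert flop_counter.get_total_flops() == x.numel()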
+
+######################################################################
+# Limitations
 # --------------------------------------------------------------------
 #
 # As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
 # includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
 # You can use these features together to build complex, high-performance models.
 #
+# PyTorch 2.6 added ``torch.library.triton_op``, which extends user-defined
+# Triton kernel support to tensor subclasses and other advanced features.
+#
 # However, there are certain limitations to be aware of:
 #
-# * **Tensor Subclasses:** Currently, there is no support for
-#   tensor subclasses and other advanced features.
 # * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
 #   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
 #   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used