Update user-defined triton kernels tutorial with new torch.library.triton_op #3227
@@ -140,17 +140,224 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composibility and Limitations
# Composability
# -------------------------------------------------------------------
#
# User-defined Triton kernels do not automatically support all PyTorch
# subsystems. This can be seen in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed by one
# or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
# you are able to specify the interactions with PyTorch subsystems via ``torch.library``.
# However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to
# ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
#
# Here is a table of which API to use when integrating Triton kernels with PyTorch.
#
# .. list-table::
#    :header-rows: 1
#
#    * -
#      - Triton kernel (no explicit ``torch.library`` wrapper)
#      - ``torch.library.triton_op``
#      - ``torch.library.custom_op``
#    * - Supports inference
#      - Yes
#      - Yes
#      - Yes
#    * - Supports training
#      - In the majority of cases
#      - Yes
#      - Yes
#    * - Supports ``torch.compile``
#      - Yes
#      - Yes
#      - Yes
#    * - Supports ``torch.compile(fullgraph=True)``
#      - In the majority of cases
#      - In the majority of cases
#      - In all cases
#    * - Does ``torch.compile`` trace into the implementation?
#      - Yes
#      - Yes
#      - No
#    * - Supports AOTInductor
#      - Yes
#      - Yes
#      - No
#    * - Supports PyTorch subsystems like ``FlopCounterMode``, CPU fallback, and tensor subclasses
#      - No
#      - Yes
#      - Yes
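
######################################################################
# For contrast, here is a minimal sketch (not part of the tutorial's original code) of
# the same operation written as an opaque ``torch.library.custom_op``. The name
# ``mysin_opaque`` is hypothetical and used only for illustration; the key point,
# per the table above, is that ``torch.compile`` does not trace into its body.

@torch.library.custom_op("mylib::mysin_opaque", mutates_args=())
def mysin_opaque(x: torch.Tensor) -> torch.Tensor:
    # Eager implementation; torch.compile treats this op as a black box.
    return torch.sin(x)

@mysin_opaque.register_fake
def _(x):
    # Fake (meta) kernel so the op can be traced symbolically.
    return torch.empty_like(x)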

######################################################################
# Wrapping Triton kernels with ``triton_op``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
# Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={}) | ||
def mysin(x: torch.Tensor) -> torch.Tensor: | ||
out = torch.empty_like(x) | ||
n_elements = x.numel() | ||
wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) | ||
return out | ||
|
||
@triton.jit | ||
def sin_kernel( | ||
in_ptr0, | ||
out_ptr, | ||
n_elements, | ||
BLOCK_SIZE: "tl.constexpr", | ||
): | ||
pid = tl.program_id(axis=0) | ||
block_start = pid * BLOCK_SIZE | ||
offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
mask = offsets < n_elements | ||
x = tl.load(in_ptr0 + offsets, mask=mask) | ||
output = tl.sin(x) | ||
tl.store(out_ptr + offsets, output, mask=mask) | ||
|
||
def sin_triton(x): | ||
out = torch.empty_like(x) | ||
n_elements = x.numel() | ||
sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) | ||
return out | ||
Review comment on lines +230 to +234: I guess this is the user-defined triton kernel w/ no

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
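
######################################################################
# As a rough sketch (not from the original tutorial), the ``triton_op`` can also be
# compiled ahead-of-time with AOTInductor. The wrapper module ``MySinModule`` below is
# hypothetical, and this assumes PyTorch 2.6+ with a CUDA device; the exact AOTInductor
# packaging API may vary across versions.
#
# >>> import torch.nn as nn
# >>> class MySinModule(nn.Module):
# >>>     def forward(self, x):
# >>>         return mysin(x)
# >>>
# >>> ep = torch.export.export(MySinModule(), (x,))
# >>> package_path = torch._inductor.aoti_compile_and_package(ep)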

######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
# with ``torch.compile``).

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
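
######################################################################
# A quick sanity check of the registered backward (not in the original tutorial;
# assumes a CUDA device is available):

x = torch.randn(3, device="cuda", requires_grad=True)
y = mysin(x)
y.sum().backward()
assert torch.allclose(x.grad, x.cos())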

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators.
# If you want the backward to call Triton kernels, then those must be wrapped in ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())
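
######################################################################
# A small additional check (not from the original tutorial, and assuming the dispatcher
# also routes CPU calls to the registered fallback under ``torch.compile``):
#
# >>> y_compiled = torch.compile(mysin)(x)
# >>> assert torch.allclose(y_compiled, x.sin())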

######################################################################
# The fallback must be composed of PyTorch operators.

######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many flops the Triton kernel reports under PyTorch's flop counter,
# use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

#########################################################
# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
# Before running the code below, make sure you have ``tabulate`` installed or install it
# by running ``pip install tabulate``.
#
# >>> with FlopCounterMode() as flop_counter:
# >>>     y = mysin(x)
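
######################################################################
# A short verification sketch (not in the original tutorial): with ``display=False`` the
# table is not printed, and ``get_total_flops()`` should report the value returned by the
# registered formula, that is, ``x.numel()``. This still assumes ``tabulate`` is installed.
#
# >>> with FlopCounterMode(display=False) as flop_counter:
# >>>     y = mysin(x)
# >>> assert flop_counter.get_total_flops() == x.numel()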

Review comment: Having this title be limitations, and then having two paragraphs of not-limitations is a bit weird

######################################################################
# Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for
# using user-defined Triton kernels with tensor subclasses and other advanced features.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Currently, Triton kernels that are not wrapped in
#   ``torch.library.triton_op`` have no support for tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first.