Update user-defined triton kernels tutorial with new torch.library.triton_op #3227
@@ -140,17 +140,220 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability
# -------------------------------------------------------------------
#
# User-defined Triton kernels do not automatically support all PyTorch
# subsystems, as in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed by one
# or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
# you are able to specify the interactions with PyTorch subsystems via ``torch.library``.
# However, unlike ``torch.library.custom_op``, which creates opaque callables with
# respect to ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply
# optimizations.
#
# Here is a chart of which API to use when integrating Triton kernels with PyTorch.
#
# .. list-table::
#    :header-rows: 1
#
#    * -
#      - Triton kernel (no explicit ``torch.library`` wrapper)
#      - ``torch.library.triton_op``
#      - ``torch.library.custom_op``
#    * - Supports inference
#      - Yes
#      - Yes
#      - Yes
#    * - Supports training
#      - In the majority of cases
#      - Yes
#      - Yes
#    * - Supports ``torch.compile``
#      - Yes
#      - Yes
#      - Yes
#    * - Supports ``torch.compile(fullgraph=True)``
#      - In the majority of cases
#      - In the majority of cases
#      - In all cases
#    * - Does ``torch.compile`` trace into the implementation?
#      - Yes
#      - Yes
#      - No
#    * - Supports AOTInductor
#      - Yes
#      - Yes
#      - No
#    * - Supports PyTorch subsystems like ``FlopCounterMode``, CPU fallback, and tensor subclasses
#      - No
#      - Yes
#      - Yes

######################################################################
# Wrapping Triton kernels with ``triton_op``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more
# Triton kernels. Use ``torch.library.wrap_triton`` to wrap the calls to the
# Triton kernel.

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out


@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)


# For comparison: the same kernel invoked directly, without a
# ``torch.library`` wrapper.
def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
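
######################################################################
# For the ``AOTInductor`` path, here is a minimal sketch (not part of the
# original tutorial; it assumes the ``torch.export.export`` and
# ``torch._inductor.aoti_compile_and_package`` APIs of recent PyTorch
# releases, and the hypothetical module name ``MySinModule``): export a
# module that calls the ``triton_op``, then compile it ahead of time.
#
# .. code-block:: python
#
#     # Wrap the triton_op call in a module so it can be exported.
#     class MySinModule(torch.nn.Module):
#         def forward(self, x):
#             return mysin(x)
#
#     ep = torch.export.export(MySinModule(), (x,))
#     package_path = torch._inductor.aoti_compile_and_package(ep)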

######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function``, which has various
# composability footguns when used with ``torch.compile``.

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()


def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)


mysin.register_autograd(backward, setup_context=setup_context)
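
######################################################################
# As a quick sanity check (a usage sketch added here, not part of the
# original tutorial): since the derivative of ``sin`` is ``cos``, the
# registered backward should produce ``x.cos()``.

x = torch.randn(3, device="cuda", requires_grad=True)
y = mysin(x)
y.sum().backward()
assert torch.allclose(x.grad, x.cos())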

######################################################################
# Note that the backward must be a composition of PyTorch-understood
# operators. If you want the backward to call Triton kernels, then those
# must be wrapped in ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)


@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out


def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)


def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)


mysin.register_autograd(backward, setup_context=setup_context)
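
######################################################################
# A brief usage sketch (added here as an assumption about the standard flow,
# not part of the original tutorial): with the re-registered backward, the
# gradient now flows through the ``mycos`` ``triton_op``, and the full
# forward-backward pass still works under ``torch.compile``.

x = torch.randn(3, device="cuda", requires_grad=True)
y = torch.compile(mysin)(x)
y.sum().backward()
assert torch.allclose(x.grad, x.cos())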

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
#
# Triton kernels don't run on CPU. Use ``register_kernel`` to add a CPU
# (or any other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)


x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

######################################################################
# The fallback must be composed of PyTorch operators.

######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many flops the Triton kernel reports under PyTorch's flop
# counter, use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel


x = torch.randn(3, device="cuda")

######################################################################
# Note that ``FlopCounterMode`` requires the ``tabulate`` package to print
# its table, so the following is shown but not run here:
#
# >>> with FlopCounterMode() as flop_counter:
# >>>     y = mysin(x)
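
######################################################################
# If ``tabulate`` is unavailable, a hedged alternative (this assumes
# ``FlopCounterMode``'s ``display`` argument and its ``get_total_flops()``
# accessor) is to suppress the printed table and read the count
# programmatically:
#
# >>> with FlopCounterMode(display=False) as flop_counter:
# >>>     y = mysin(x)
# >>> print(flop_counter.get_total_flops())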

######################################################################
# Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for
# user-defined Triton kernels in tensor subclasses and other advanced features.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Without a ``triton_op`` wrapper, there is currently no
#   support for tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first.