Commit ad1dc65
authored
[FRONTEND] Allow JITFunctions as arguments to other JITFunctions (#5723)
This PR allows a call to a JITFunction to pass another JITFunction as an
argument.
For example:
```python
@triton.jit
def fn_a(x):
...
@triton.jit
def fn_b(x, fn):
...
@triton.jit
def fn_c(x):
return fn_b(x, fn_a) # fn_a (a JITFunction) is passed as an argument to fn_b (another JITFunction)
```
Prior to #5220, this worked. After #5220, the user needs to annotate the
JITFunctions with @triton.constexpr manually (until this PR).
Use case: Inductor has some generic helper functions for implementing
scans (e.g. exclusive_scan_decoupled_lookback) which take a `combine_fn`
to implement the combination function (similar to tl.reduce). These
helper functions have stopped working after #5220.
https://github.com/pytorch/pytorch/blob/01a4d86b31365cfb484dc17885c9a7ee09c235ab/torch/_inductor/runtime/triton_helpers.py#L3211 parent b55477a commit ad1dc65
File tree
2 files changed
+30
-1
lines changed- python
- test/unit/language
- triton/compiler
2 files changed
+30
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6833 | 6833 | | |
6834 | 6834 | | |
6835 | 6835 | | |
| 6836 | + | |
| 6837 | + | |
| 6838 | + | |
| 6839 | + | |
| 6840 | + | |
| 6841 | + | |
| 6842 | + | |
| 6843 | + | |
| 6844 | + | |
| 6845 | + | |
| 6846 | + | |
| 6847 | + | |
| 6848 | + | |
| 6849 | + | |
| 6850 | + | |
| 6851 | + | |
| 6852 | + | |
| 6853 | + | |
| 6854 | + | |
| 6855 | + | |
| 6856 | + | |
| 6857 | + | |
| 6858 | + | |
| 6859 | + | |
| 6860 | + | |
| 6861 | + | |
| 6862 | + | |
| 6863 | + | |
| 6864 | + | |
6836 | 6865 | | |
6837 | 6866 | | |
6838 | 6867 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1153 | 1153 | | |
1154 | 1154 | | |
1155 | 1155 | | |
1156 | | - | |
| 1156 | + | |
1157 | 1157 | | |
1158 | 1158 | | |
1159 | 1159 | | |
| |||
0 commit comments