Skip to content

Commit cd1cc2d

Browse files
authored
Refactor compiler specializations to consider backend (#4734)
In this PR I am trying to refactor the specializations that we apply to the signature of a given function in Triton. Basically, given a kernel there are some argument properties that can help compilation. E.g., divisibility by 16 and the fact that an integer is equal to 1. In a previous PR: #4716, I needed other specializations to add buffer support in the AMD backend (and get back some performance when we were using unaligned masked loads). So this is my attempt to redesign the specialization support to introduce per-backend specializations. The idea is that `AttrsDescriptor` is now the class that is taking care of doing the analysis of the parameters and adding the specialization. It also has a function table where more specializations can be added per-backend.
1 parent 057a9d3 commit cd1cc2d

File tree

8 files changed

+229
-92
lines changed

8 files changed

+229
-92
lines changed

python/test/unit/runtime/test_bindings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def walk_fn(op):
5959
torch.empty((32, 32), device=device), # out_ptr
6060
16, # BLOCK_SIZE
6161
]
62+
target = triton.runtime.driver.active.get_current_target()
63+
backend = triton.compiler.compiler.make_backend(target)
6264
src = triton.compiler.compiler.ASTSource(
6365
fn=kernel,
6466
signature={
@@ -69,12 +71,10 @@ def walk_fn(op):
6971
constants={kernel.arg_names[i]: arg
7072
for i, arg in enumerate(args)
7173
if not isinstance(arg, torch.Tensor)},
72-
attrs=kernel._get_config(*args, ),
74+
attrs=backend.get_attrs_descriptor(args, kernel.params),
7375
)
7476

7577
context = triton._C.libtriton.ir.context()
76-
target = triton.runtime.driver.active.get_current_target()
77-
backend = triton.compiler.compiler.make_backend(target)
7878
options = backend.parse_options(dict())
7979
codegen_fns = dict()
8080
module_map = backend.get_module_map()

python/test/unit/runtime/test_subproc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import triton
55
import triton.language as tl
6+
from triton.backends.compiler import AttrsDescriptor
67
from triton.compiler import ASTSource
78

89
target = triton.runtime.driver.active.get_current_target()
@@ -25,7 +26,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2526

2627

2728
def test_compile_in_subproc() -> None:
28-
config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())
29+
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
2930
multiprocessing.set_start_method('fork')
3031
proc = multiprocessing.Process(target=compile_fn, args=(config, ))
3132
proc.start()
@@ -47,7 +48,7 @@ def kernel_dot(Z):
4748

4849

4950
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
50-
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
51+
config = AttrsDescriptor.from_hints({0: 16})
5152
assert multiprocessing.get_start_method() == 'fork'
5253
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
5354
proc.start()
@@ -86,7 +87,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8687
gc.disable()
8788

8889
# stage 1.p
89-
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
90+
config = AttrsDescriptor.from_hints({0: 16})
9091
compile_empty_kernel_with_gc(config)
9192

9293
# stage 2.p

python/triton/backends/compiler.py

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import re
3+
import hashlib
34
import subprocess
45

56
from abc import ABCMeta, abstractmethod, abstractclassmethod
@@ -8,6 +9,184 @@
89
from types import ModuleType
910

1011

12+
class AttrsDescriptor:
13+
"""
14+
This class handles compile-time properties for specific function parameters.
15+
16+
Different backends can add more properties to the common ones. The class
17+
contains two fields:
18+
19+
`arg_properties`: a dictionary containing the different compile-time properties for different
20+
parameters. I.e., the dictionary is a map from property names to parameter indices
21+
{
22+
"prop0": (0, 2, 3)
23+
"prop1": (0, 4, 5)
24+
}
25+
Different backends might need different properties on those paraemters to enable
26+
specific optimizations. The common compile time properties contained in this class
27+
are :
28+
- "tt.divisibility", i.e., is the given parameter divisible by 16
29+
- "tt.equal_to_1", i.e., is the given parameter an integer constant 1
30+
31+
`property_values`: a dictionary containing the value of the different compile-time properties, like:
32+
{
33+
"prop0": val0
34+
"prop1": val1
35+
}
36+
37+
`constant_properties`: a set containing the properties that can be used to determine if a parameter is constant
38+
39+
"""
40+
__slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')
41+
42+
def __init__(self, params=None, values=None):
43+
"""
44+
Initialize the compile-time properties
45+
46+
We can initialize the AttrsDescriptor class by passing the list of params
47+
of the function and their `values`. The function will try to apply the properties
48+
to the values and save the parameters in the `arg_properties` list. If we don't pass
49+
either the `params` or the `values` we should initialize the class via an alternative method
50+
(see `from_dict` or `from_hints`)
51+
"""
52+
# Default initialization
53+
self.arg_properties = {}
54+
self.property_values = {}
55+
self.constant_properties = set()
56+
57+
self._add_common_properties(params, values)
58+
self._add_backend_properties(params, values)
59+
self._init_slots()
60+
61+
def _add_common_properties(self, params, values):
62+
""" Add common compile-time properties """
63+
self.property_values["tt.divisibility"] = 16
64+
self.property_values["tt.equal_to"] = 1
65+
self.constant_properties.add("tt.equal_to")
66+
67+
if (params is None) or (values is None):
68+
return
69+
70+
# Compile properties deduction
71+
assert (len(params) == len(values))
72+
73+
# Divisibility property
74+
self.arg_properties["tt.divisibility"] = [
75+
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
76+
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
77+
]
78+
79+
# Equal to 1 property
80+
self.arg_properties["tt.equal_to"] = [
81+
param.num
82+
for param, arg in zip(params, values)
83+
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
84+
]
85+
86+
def _add_backend_properties(self, params=None, values=None):
87+
""" This method is for different subclasses to implement their own compile-time properties """
88+
pass
89+
90+
def _init_slots(self):
91+
""" Initialize the slots of this class """
92+
for name, val in self.arg_properties.items():
93+
setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)
94+
95+
def get_fn_attrs(self) -> Dict:
96+
"""
97+
Get the function attributes as a dictionary.
98+
99+
The returned dictionary will look like :
100+
{
101+
"arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
102+
"arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
103+
}
104+
"""
105+
attrs = {}
106+
for prop_name, arg_set in self.arg_properties.items():
107+
prop_val = self.property_values[prop_name]
108+
for arg in arg_set:
109+
attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
110+
return attrs
111+
112+
def get_constants(self) -> Dict:
113+
""" Return a mapping of constant parameters to their values """
114+
constants = {}
115+
for prop_name in self.constant_properties:
116+
for p in self.arg_properties.get(prop_name, []):
117+
constants[p] = self.property_values[prop_name]
118+
return constants
119+
120+
def filter_out_constants(self):
121+
""" Return the same object, without properties marked as constants"""
122+
import copy
123+
c = copy.deepcopy(self)
124+
for prop_name in c.constant_properties:
125+
c.arg_properties.pop(prop_name, None)
126+
c.property_values.pop(prop_name, None)
127+
c.constant_properties = {}
128+
return c
129+
130+
def hash(self):
131+
values = [sorted(self.arg_properties.values())]
132+
values += [sorted(self.property_values.values())]
133+
values += [sorted(self.constant_properties)]
134+
key = str(values)
135+
return hashlib.sha256(key.encode("utf-8")).hexdigest()
136+
137+
def to_dict(self):
138+
return self.arg_properties
139+
140+
@staticmethod
141+
def from_dict(data):
142+
attrsDescriptor = AttrsDescriptor()
143+
for prop_name, param_ids in data.items():
144+
attrsDescriptor.arg_properties[prop_name] = param_ids
145+
attrsDescriptor._init_slots()
146+
return attrsDescriptor
147+
148+
@staticmethod
149+
def from_hints(hints: list[tuple[int, int]]):
150+
"""
151+
Create the class from a set of hints that are passed in.
152+
153+
Instead of deducing the properties from a list of paramaters and values,
154+
the user can pass in a list of `hints=[(param_index, val)]` and if `val`
155+
matches one of the values of the properties (e.g., `prop_val[prop0]`),
156+
then we insert `param_index` into the correct list (e.g., in
157+
`arg_properties[prop0]`)
158+
"""
159+
attrsDescriptor = AttrsDescriptor()
160+
for prop_name, prop_val in attrsDescriptor.property_values.items():
161+
attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
162+
attrsDescriptor._init_slots()
163+
return attrsDescriptor
164+
165+
@staticmethod
166+
def is_divisible_by_16(x):
167+
""" Return if the argument is a multiple of 16"""
168+
if hasattr(x, "data_ptr"):
169+
return x.data_ptr() % 16 == 0
170+
elif isinstance(x, int):
171+
return x % 16 == 0
172+
if x is None:
173+
return True
174+
return False
175+
176+
@staticmethod
177+
def is_equal_to_1(x):
178+
""" Return if the argument is a constant 1"""
179+
return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False
180+
181+
@staticmethod
182+
def get_property_key(val, align):
183+
if align and AttrsDescriptor.is_divisible_by_16(val):
184+
return "D"
185+
if AttrsDescriptor.is_equal_to_1(val):
186+
return "1"
187+
return "N"
188+
189+
11190
@dataclass(frozen=True)
12191
class GPUTarget(object):
13192
# Target backend, e.g., cuda, hip
@@ -79,6 +258,20 @@ def load_dialects(self, context):
79258
@abstractmethod
80259
def get_module_map(self) -> Dict[str, ModuleType]:
81260
"""
82-
Return a map of interface modules to their device-specific implementations.
261+
Return a map of interface modules to their device-specific implementations
83262
"""
84263
raise NotImplementedError
264+
265+
def get_attrs_descriptor(self, params, args):
266+
"""
267+
Return an attribute descriptor: given a set of parameters and arguments
268+
the descriptor stores a set of compile time properties that can improve code
269+
generation. Different backends might benefit from different properties
270+
"""
271+
return AttrsDescriptor(params, args)
272+
273+
def compute_spec_key(self, arg, align):
274+
"""
275+
Return the ascii key for a given argument with a given set of properties
276+
"""
277+
return AttrsDescriptor.get_property_key(arg, align)

python/triton/compiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict
1+
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
22
from .errors import CompilationError
33

44
__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]

python/triton/compiler/code_generator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,7 @@ def kernel_suffix(signature, specialization):
12701270
suffix += str(i)
12711271
if i in specialization.equal_to_1:
12721272
suffix += 'c'
1273-
if i in specialization.divisible_by_16:
1273+
if i in specialization.divisibility_16:
12741274
suffix += 'd'
12751275
return suffix
12761276

@@ -1284,17 +1284,21 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
12841284
gscope = fn.__globals__.copy()
12851285
function_name = fn.repr(specialization)
12861286
tys = list(specialization.signature.values())
1287-
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1}
1288-
new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16}
1287+
new_constants = attrs.get_constants()
1288+
for k in new_constants:
1289+
if k in tys and tys[k] == "i1" and new_constants[k] == 1:
1290+
new_constants[k] = True
12891291

1292+
new_attrs = attrs.filter_out_constants()
1293+
fn_attrs = new_attrs.get_fn_attrs()
12901294
all_constants = constants.copy()
12911295
all_constants.update(new_constants)
12921296
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
12931297
file_name, begin_line = get_jit_fn_file_line(fn)
12941298

12951299
prototype = language.function_type([], arg_types)
12961300
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
1297-
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
1301+
jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name,
12981302
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
12991303
generator.visit(fn.parse())
13001304

python/triton/compiler/compiler.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,19 @@
33
import json
44
from .._C.libtriton import get_cache_invalidating_env_vars, ir
55
from ..backends import backends
6-
from ..backends.compiler import GPUTarget
6+
from ..backends.compiler import GPUTarget, AttrsDescriptor
77
from .. import __version__
88
from ..runtime.autotuner import OutOfResources
99
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
1010
from ..runtime.driver import driver
1111
from ..tools.disasm import get_sass
1212
# TODO: this shouldn't be here
13-
from dataclasses import dataclass
1413
from .code_generator import ast_to_ttir
1514
from pathlib import Path
1615
import re
1716
import functools
1817
import os
1918

20-
21-
@dataclass
22-
class AttrsDescriptor:
23-
divisible_by_16: set = None
24-
equal_to_1: set = None
25-
26-
def __post_init__(self):
27-
if self.divisible_by_16 is None:
28-
self.divisible_by_16 = set()
29-
if self.equal_to_1 is None:
30-
self.equal_to_1 = set()
31-
32-
def to_dict(self):
33-
return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)}
34-
35-
@staticmethod
36-
def from_dict(data):
37-
return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])),
38-
equal_to_1=set(data.get('equal_to_1', [])))
39-
40-
def hash(self):
41-
key = str([sorted(x) for x in self.__dict__.values()])
42-
return hashlib.sha256(key.encode("utf-8")).hexdigest()
43-
44-
4519
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
4620
# and any following whitespace
4721
# - (public\s+)? : optionally match the keyword public and any following whitespace

0 commit comments

Comments
 (0)