Commit 3b21a67

[AMD][MLA] Fix mla autotune for rocm (#861)
* Refactor matmul example to include ReLU activation and update batch size in benchmark script
* lint fix
* Enhance autotuning capabilities in benchmark script and update argument defaults
  - Introduced a new `get_configs` function to generate autotuning configurations for the benchmark.
  - Updated the default batch size and kv context length in the argument parser for improved performance.
  - Renamed the `--auto_tune` argument to `--autotune` for consistency.
  - Modified the kernel invocation logic to support autotuning based on the new configurations.
* lint fix
1 parent b9a51c4 commit 3b21a67

File tree

2 files changed: +50 -43 lines changed

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
tilelang/autotuner/tuner.py

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py

Lines changed: 37 additions & 39 deletions
@@ -1,14 +1,31 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang.autotuner import *
 import tilelang.language as T
 from einops import rearrange, einsum
 import argparse
 
 tilelang.disable_cache()
 
 
+def get_configs():
+    import itertools
+    BLOCK_N = [16, 32, 64, 128]
+    BLOCK_H = [16, 32, 64, 128]
+    num_split = [1, 2, 4, 8, 16, 32]
+    threads = [128, 256]
+
+    _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
+
+    return [{
+        "block_N": c[0],
+        "block_H": c[1],
+        "num_split": c[2],
+        "threads": c[3],
+    } for c in _configs]
+
+
+@tilelang.autotune(configs=get_configs())
 @tilelang.jit(
     out_idx=[6], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
@@ -273,26 +290,39 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
+    parser.add_argument('--batch', type=int, default=128, help='batch size')
     parser.add_argument('--heads', type=int, default=128, help='q heads number')
     parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-    parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
+    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
     parser.add_argument('--dim', type=int, default=512, help='head dim')
     parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
-    parser.add_argument('--auto_tune', action='store_true', help='auto tune')
+    parser.add_argument('--autotune', action='store_true', help='auto tune')
     args = parser.parse_args()
     batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
-    enable_autotune = args.auto_tune
+    enable_autotune = args.autotune
 
     qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
     pv_flops = 2 * batch * heads * kv_ctx * dim
     total_flops = qk_flops + pv_flops
     BLOCK_N = 32
     BLOCK_H = 64
     num_split = 4
+    threads = 128
 
-    kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
-                             num_split)
+    if enable_autotune:
+        kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
+    else:
+        kernel = flashmla_decode(
+            batch,
+            heads,
+            kv_heads,
+            kv_ctx,
+            dim,
+            pe_dim,
+            BLOCK_N,
+            BLOCK_H,
+            num_split,
+            threads=threads)
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     input_tensors = profiler._get_inputs()
     tilelang_output = kernel(*input_tensors)
@@ -303,35 +333,3 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     latency = profiler.do_bench(warmup=500)
     print(f"Latency: {latency} ms")
     print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
-
-    # Enable Auto Tuning
-
-
-    def get_configs():
-        import itertools
-        BLOCK_N = [16, 32, 64, 128]
-        BLOCK_H = [16, 32, 64, 128]
-        num_split = [1, 2, 4, 8, 16, 32]
-        thread_num = [128, 256]
-
-        _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, thread_num))
-
-        return [{
-            "block_N": c[0],
-            "block_H": c[1],
-            "num_split": c[2],
-            "thread_num": c[3],
-        } for c in _configs]
-
-    def wrapped_kernel(block_N=None, block_H=None, num_split=None, thread_num=None):
-        return flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H,
-                               num_split, thread_num)
-
-    if enable_autotune:
-        autotuner = AutoTuner.from_kernel(kernel=wrapped_kernel, configs=get_configs())
-        tune_result = autotuner.run(warmup=3, rep=20)
-        best_latency = tune_result.latency
-        best_config = tune_result.config
-        print(f"Best latency: {best_latency} ms")
-        print(f"Best TFlops: {total_flops / best_latency * 1e-9} TFlops")
-        print(f"Best config: {best_config}")
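For context, a minimal sketch (assuming the `flashmla_decode` factory defined in this file) of how the new autotune path is meant to be driven: `get_configs` enumerates the tunable tile/launch parameters with `itertools.product`, and the decorated factory is called either with only the problem shape, letting `@tilelang.autotune` sweep those configurations, or with explicit tile parameters, which bypasses the tuner. The commented-out calls are illustrative, not part of the diff.

import itertools

def get_configs():
    # Cartesian product of the tunable tile/launch parameters (as in the diff above).
    BLOCK_N = [16, 32, 64, 128]
    BLOCK_H = [16, 32, 64, 128]
    num_split = [1, 2, 4, 8, 16, 32]
    threads = [128, 256]
    return [{
        "block_N": n,
        "block_H": h,
        "num_split": s,
        "threads": t,
    } for n, h, s, t in itertools.product(BLOCK_N, BLOCK_H, num_split, threads)]

print(len(get_configs()))  # 4 * 4 * 6 * 2 = 192 candidate configurations

# Autotuned path: only the problem shape is passed; block_N/block_H/num_split/threads
# are chosen by @tilelang.autotune from get_configs().
#     kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
# Fixed-config path: the tile parameters are passed explicitly and tuning is skipped.
#     kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim,
#                              BLOCK_N, BLOCK_H, num_split, threads=threads)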

tilelang/autotuner/tuner.py

Lines changed: 13 additions & 4 deletions
@@ -104,6 +104,7 @@ class AutoTuner:
     profile_args = ProfileArgs()
 
     _kernel_parameters: Optional[Tuple[str, ...]] = None
+    _function_parameters: Optional[Dict[str, Any]] = None
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
     cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
@@ -222,9 +223,10 @@ def set_profile_args(self,
 
         return self
 
-    def set_kernel_parameters(self, parameters: Tuple[str, ...]):
+    def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
         # for cache key generation
-        self._kernel_parameters = parameters
+        self._kernel_parameters = k_parameters
+        self._function_parameters = f_parameters
 
     def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
         """Generate a cache key for the auto-tuning process.
@@ -417,8 +419,15 @@ def shape_equal(a, b):
         key_args_tuple, key_kwargs_tuple = self._kernel_parameters
         tunable_arguments = [key for key, _ in top_config.items()]
 
+        def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool:
+            params_list = list(parameters.keys())
+            assert key in params_list, f"Tunable argument {key} not found in function parameters"
+            return params_list.index(key) < len(key_args_tuple)
+
         # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
-        if any(key in top_config for key, _ in key_kwargs_tuple):
+        if any(key in top_config for key, _ in key_kwargs_tuple) or any(
+                check_tunable_argument_value(key, self._function_parameters, key_args_tuple)
+                for key in tunable_arguments):
             logger.warning(
                 f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
             )
@@ -676,7 +685,7 @@ def jit_compile(**config_arg):
        )
 
        autotuner.jit_compile = jit_compile
-       autotuner.set_kernel_parameters(key)
+       autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
 
        autotuner.run = partial(autotuner.run, warmup, rep, timeout)
 

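The tuner change above records the kernel function's signature (`inspect.signature(fn).parameters`) so that `check_tunable_argument_value` can detect a tunable argument that was supplied positionally: the argument counts as already provided when its index in the signature is smaller than the number of positional arguments captured in the cache key, in which case the tuner falls back to direct JIT. Below is a standalone sketch of that check, using a hypothetical `kernel_factory` and an illustrative positional-argument tuple.

import inspect

def kernel_factory(batch, heads, kv_ctx, block_N=None, block_H=None, threads=None):
    """Hypothetical kernel factory; the last three parameters are the tunable ones."""

# Ordered parameter names, as captured via set_kernel_parameters(key, inspect.signature(fn).parameters).
params_list = list(inspect.signature(kernel_factory).parameters.keys())

# Suppose the caller passed four positional arguments, e.g. kernel_factory(1, 128, 8192, 64):
# block_N was therefore supplied positionally.
key_args_tuple = (1, 128, 8192, 64)

def check_tunable_argument_value(key):
    # A tunable argument is "already provided" when its position in the signature
    # falls inside the positional arguments recorded in the cache key.
    assert key in params_list, f"Tunable argument {key} not found in function parameters"
    return params_list.index(key) < len(key_args_tuple)

print(check_tunable_argument_value("block_N"))  # True  -> skip tuning, use direct JIT
print(check_tunable_argument_value("threads"))  # False -> still tunable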