
Commit b9a51c4

[TMA] Bugfix for when a shared buffer is issued with both a TMA store and a TMA load (#857)
- Updated `init_desc_arg_map` in `lower_hopper_intrin.cc` to use `Var` as the key instead of `String`.
- Enhanced the `func_call_args` method in `TLCUDASourceWrapper` to accept additional parameters for better argument mapping.
- Added assertions to ensure consistency between function parameters and arguments during kernel launches.
- Modified `generate_tma_descriptor_args` to resolve TMA descriptor arguments through a mapping from descriptor handle names to kernel parameter variables.
Parent: 058a670

2 files changed: +74 −13 lines


src/transform/lower_hopper_intrin.cc

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ class LowerHopperIntrin : public StmtExprMutator {
   PrimFuncNode *fptr = f.CopyOnWrite();
   LowerHopperIntrin substituter(disable_shuffle_elect);
   fptr->body = substituter.VisitStmt(f->body);
-  Map<String, Array<PrimExpr>> init_desc_arg_map;
+  Map<Var, Array<PrimExpr>> init_desc_arg_map;
   for (const auto &[call, var] : substituter.desc_map_) {
     // Should allocate 128 bytes for TensorMap on stack
     Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
@@ -46,7 +46,7 @@ class LowerHopperIntrin : public StmtExprMutator {
         Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
     fptr->body =
         LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
-    init_desc_arg_map.Set(var->name_hint, init_desc_args);
+    init_desc_arg_map.Set(var, init_desc_args);
   }
   f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
   return f;
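
Why keying by `Var` matters: when the same shared buffer is issued with both a TMA load and a TMA store, the pass creates two descriptor handle `Var`s that can carry the same name hint, so a `String`-keyed map silently keeps only the last entry. A minimal sketch of that effect, using plain Python dicts as stand-ins for TVM's `Map` (the variable names here are hypothetical):

```python
import tvm
from tvm import tir

# Two distinct handle Vars with the same name hint, as when one shared buffer
# backs both a TMA-load descriptor and a TMA-store descriptor.
load_desc = tir.Var("buf_desc", "handle")
store_desc = tir.Var("buf_desc", "handle")

by_name = {}
by_name[load_desc.name] = ["tma load args"]
by_name[store_desc.name] = ["tma store args"]   # overwrites the load entry
assert len(by_name) == 1

by_var = {}
by_var[load_desc] = ["tma load args"]
by_var[store_desc] = ["tma store args"]          # both descriptors survive
assert len(by_var) == 2
```

Keying by the `Var` object keeps one entry per descriptor, which is what the host-side wrapper changes below rely on.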

tilelang/jit/adapter/wrapper.py

Lines changed: 72 additions & 11 deletions
@@ -8,6 +8,7 @@
 import re
 import logging
 import textwrap
+from tvm.tir.stmt_functor import post_order_visit
 
 PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
 cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
@@ -260,7 +261,11 @@ def create_dispatch_func(self, code, function_informations):
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
-        def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
+        def func_call_args(s,
+                           function_args,
+                           function_params,
+                           desc_name_map: Optional[Dict[str, str]] = None,
+                           desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None):
             # Extract the function call arguments matching the function definition
             def maybe_desc(name: str, matches: List[str], i: int):
                 match = matches[i]
@@ -280,8 +285,15 @@ def maybe_desc(name: str, matches: List[str], i: int):
             call_args = []
             for i, match in enumerate(matches):
                 for arg in function_args:
-                    if arg["name"] == match or maybe_desc(arg["name"], matches, i):
+                    if arg["name"] == match:
                         call_args.append(match)
+                    elif maybe_desc(arg["name"], matches, i):
+                        call_args.append(match)
+                        assert len(call_args) <= len(
+                            function_params
+                        ), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments"
+                        desc_name_var_map[match] = function_params[len(call_args) - 1]
+
             return call_args
 
         has_l2_persistent_map = False
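
The positional bookkeeping added above can be summarized with a small stand-alone sketch (hypothetical names, plain Python in place of the regex-driven matching in `func_call_args`): the k-th entry appended to `call_args` lines up with the k-th kernel parameter, so a descriptor argument can be tied back to the TIR `Var` recorded in `function_params`.

```python
# Hypothetical stand-ins: in the real wrapper these come from the parsed kernel
# declaration (matches) and the host-side tvm_call_packed arguments (function_params).
function_params = ["A_handle", "B_handle", "buf_desc_param"]   # k-th kernel parameter
matches = ["A", "B", "buf_desc"]                               # k-th argument in the declaration

desc_name_var_map = {}
call_args = []
for i, match in enumerate(matches):
    call_args.append(match)
    if match.endswith("_desc"):                                # stand-in for maybe_desc()
        # len(call_args) - 1 is this argument's position, which is also its
        # position in the kernel parameter list.
        desc_name_var_map[match] = function_params[len(call_args) - 1]

print(desc_name_var_map)   # {'buf_desc': 'buf_desc_param'}
```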
@@ -294,10 +306,12 @@ def maybe_desc(name: str, matches: List[str], i: int):
         if has_l2_persistent_map:
             kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
         desc_name_map: Dict[str, str] = {}
+        desc_name_var_map: Dict[str, tvm.tir.Var] = {}
         for function_name, function_info in function_informations.items():
             block_info = function_info["block_info"]
             grid_info = function_info["grid_info"]
             dynamic_smem_buf = function_info["dynamic_smem_buf"]
+            function_params = function_info["function_params"]
 
             # Find the location of the global kernel function in the code
             index = match_declare_kernel(code, function_name + "(")
@@ -321,22 +335,32 @@ def maybe_desc(name: str, matches: List[str], i: int):
                 kernel_launch_code += init_l2_persistent_map
 
             if self.use_cooperative_groups[function_name]:
-                args_list = func_call_args(declaration, function_args, desc_name_map)
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
                 args_array = [f"(void*)&{arg}" for arg in args_list]
                 call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
                 kernel_launch_code += call_args
                 # Using cudaLaunchCooperativeKernel to launch the kernel
                 kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
                     function_name, grid_str, block_str, function_name + "_args", smem_str)
             else:
-                call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+                call_args = ", ".join(args_list)
                 kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                     function_name, grid_str, block_str, smem_str, call_args)
             kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
             if has_l2_persistent_map:
                 kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
 
-        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
+        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map,
+                                                                     desc_name_var_map)
         kernel_launch_code = init_tma_descriptor_args + kernel_launch_code
 
         # Wrap the kernel dispatch logic in an external C function
@@ -362,15 +386,17 @@ def generate_l2_persistent_map(self, function_name: str) -> str:
 
         return init_l2_persistent_map
 
-    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
+                                     desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
         tma_descripter_init = ""
         if self.tma_descriptor_args is None:
             return tma_descripter_init
+        for handle_name, _ in desc_name_map.items():
+            assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
+            desc_var = desc_name_var_map[handle_name]
 
-        for handle_name, name in desc_name_map.items():
-            desc_name = name + "_desc"
-            assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
-            args = self.tma_descriptor_args[desc_name]
+            assert desc_var in self.tma_descriptor_args, f"TMA descriptor {desc_var} not found in {self.tma_descriptor_args}"
+            args = self.tma_descriptor_args[desc_var]
             # Skip __tvm_tensormap_create_tiled
             if len(args) < 3:
                 raise ValueError(
@@ -536,12 +562,35 @@ def update_lib_code(self, code: str):
             # Do not update function with dispatch host function
             if (function_name not in self.block_info) or (function_name not in self.grid_info):
                 continue
+            assert function_name in self.device_mod, f"Function {function_name} not found in device module"
+            device_func = self.device_mod[function_name]
+            kernel_params_cnt = len(device_func.params)
+            function_params: List[str] = None
+
+            def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
+                nonlocal function_params
+                if isinstance(node, tvm.tir.Call):
+                    if not (hasattr(node, "op") and
+                            node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
+                        return
+                    args = node.args
+                    if not args or args[0] != fn:
+                        return
+                    if len(args) < 1 + param_cnt:
+                        raise AssertionError(
+                            "tvm_call_packed should have at least 1 argument and match device function parameters"
+                        )
+                    function_params = args[1:1 + param_cnt]
+
+            post_order_visit(self.host_func.body, visitor)
+            assert function_params is not None, "function_params should not be None"
 
             function_informations[function_name] = {
                 "function_name": function_name,
                 "block_info": self.block_info[function_name],
                 "grid_info": self.grid_info[function_name],
                 "dynamic_smem_buf": self.dynamic_smem_buf[function_name],
+                "function_params": function_params,
             }
 
         # Create the host function wrapper for the CUDA kernel
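
For context, the visitor above follows a standard TVM pattern: `post_order_visit` walks the host function body, and the callback picks out the `tir.tvm_call_packed` call whose first argument names the kernel, so the remaining arguments are the kernel's parameters. A minimal, self-contained sketch under a toy PrimFunc body (names are illustrative, not taken from the repository):

```python
import tvm
from tvm import tir
from tvm.tir.stmt_functor import post_order_visit

# Toy host body: a single packed call to a kernel named "my_kernel" with one handle arg.
param = tir.Var("x", "handle")
body = tir.Evaluate(tir.call_packed("my_kernel", param))

found = None

def visitor(node):
    global found
    if isinstance(node, tir.Call) and node.op == tvm.ir.Op.get("tir.tvm_call_packed"):
        if node.args and node.args[0] == "my_kernel":
            found = node.args[1:]          # everything after the kernel name

post_order_visit(body, visitor)
print(found)   # [x]
```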
@@ -579,6 +628,19 @@ def device_func(self):
                     return function
         raise ValueError("Cannot find primary function in the module.")
 
+    @property
+    def host_func(self):
+        if len(self.host_mod.get_global_vars()) == 1:
+            return self.host_mod[self.host_mod.get_global_vars()[0]]
+        elif "main" in self.host_mod:
+            return self.host_mod["main"]
+        else:
+            for _, function in self.host_mod.functions.items():
+                attr = function.attrs
+                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
+                    return function
+        raise ValueError("Cannot find primary function in the module.")
+
 
 class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     """
@@ -636,7 +698,6 @@ def create_dispatch_func(self, code, function_informations):
             function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
 
         function_args.append(self.get_stream_type())
-
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['name']}" for arg in function_args])
 