import re
import logging
import textwrap
+from tvm.tir.stmt_functor import post_order_visit

PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
@@ -260,7 +261,11 @@ def create_dispatch_func(self, code, function_informations):
        # Format the function arguments for declaration
        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

-        def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
+        def func_call_args(s,
+                           function_args,
+                           function_params,
+                           desc_name_map: Optional[Dict[str, str]] = None,
+                           desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None):
            # Extract the function call arguments matching the function definition
            def maybe_desc(name: str, matches: List[str], i: int):
                match = matches[i]
@@ -280,8 +285,15 @@ def maybe_desc(name: str, matches: List[str], i: int):
            call_args = []
            for i, match in enumerate(matches):
                for arg in function_args:
-                    if arg["name"] == match or maybe_desc(arg["name"], matches, i):
+                    if arg["name"] == match:
                        call_args.append(match)
+                    elif maybe_desc(arg["name"], matches, i):
+                        call_args.append(match)
+                        assert len(call_args) <= len(
+                            function_params
+                        ), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments"
+                        desc_name_var_map[match] = function_params[len(call_args) - 1]
+
            return call_args

        has_l2_persistent_map = False
@@ -294,10 +306,12 @@ def maybe_desc(name: str, matches: List[str], i: int):
        if has_l2_persistent_map:
            kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
        desc_name_map: Dict[str, str] = {}
+        desc_name_var_map: Dict[str, tvm.tir.Var] = {}
        for function_name, function_info in function_informations.items():
            block_info = function_info["block_info"]
            grid_info = function_info["grid_info"]
            dynamic_smem_buf = function_info["dynamic_smem_buf"]
+            function_params = function_info["function_params"]

            # Find the location of the global kernel function in the code
            index = match_declare_kernel(code, function_name + "(")
@@ -321,22 +335,32 @@ def maybe_desc(name: str, matches: List[str], i: int):
                kernel_launch_code += init_l2_persistent_map

            if self.use_cooperative_groups[function_name]:
-                args_list = func_call_args(declaration, function_args, desc_name_map)
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
                args_array = [f"(void*)&{arg}" for arg in args_list]
                call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
                kernel_launch_code += call_args
                # Using cudaLaunchCooperativeKernel to launch the kernel
                kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
                    function_name, grid_str, block_str, function_name + "_args", smem_str)
            else:
-                call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+                call_args = ", ".join(args_list)
                kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                    function_name, grid_str, block_str, smem_str, call_args)
            kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
            if has_l2_persistent_map:
                kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE

-        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
+        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map,
+                                                                     desc_name_var_map)
        kernel_launch_code = init_tma_descriptor_args + kernel_launch_code

        # Wrap the kernel dispatch logic in an external C function
@@ -362,15 +386,17 @@ def generate_l2_persistent_map(self, function_name: str) -> str:

        return init_l2_persistent_map

-    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
+                                     desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
        tma_descripter_init = ""
        if self.tma_descriptor_args is None:
            return tma_descripter_init
+        for handle_name, _ in desc_name_map.items():
+            assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
+            desc_var = desc_name_var_map[handle_name]

-        for handle_name, name in desc_name_map.items():
-            desc_name = name + "_desc"
-            assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
-            args = self.tma_descriptor_args[desc_name]
+            assert desc_var in self.tma_descriptor_args, f"TMA descriptor {desc_var} not found in {self.tma_descriptor_args}"
+            args = self.tma_descriptor_args[desc_var]
            # Skip __tvm_tensormap_create_tiled
            if len(args) < 3:
                raise ValueError(
@@ -536,12 +562,35 @@ def update_lib_code(self, code: str):
            # Do not update function with dispatch host function
            if (function_name not in self.block_info) or (function_name not in self.grid_info):
                continue
+            assert function_name in self.device_mod, f"Function {function_name} not found in device module"
+            device_func = self.device_mod[function_name]
+            kernel_params_cnt = len(device_func.params)
+            function_params: List[str] = None
+
+            def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
+                nonlocal function_params
+                if isinstance(node, tvm.tir.Call):
+                    if not (hasattr(node, "op") and
+                            node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
+                        return
+                    args = node.args
+                    if not args or args[0] != fn:
+                        return
+                    if len(args) < 1 + param_cnt:
+                        raise AssertionError(
+                            "tvm_call_packed should have at least 1 argument and match device function parameters"
+                        )
+                    function_params = args[1:1 + param_cnt]
+
+            post_order_visit(self.host_func.body, visitor)
+            assert function_params is not None, "function_params should not be None"

            function_informations[function_name] = {
                "function_name": function_name,
                "block_info": self.block_info[function_name],
                "grid_info": self.grid_info[function_name],
                "dynamic_smem_buf": self.dynamic_smem_buf[function_name],
+                "function_params": function_params,
            }

        # Create the host function wrapper for the CUDA kernel
@@ -579,6 +628,19 @@ def device_func(self):
                return function
        raise ValueError("Cannot find primary function in the module.")

+    @property
+    def host_func(self):
+        if len(self.host_mod.get_global_vars()) == 1:
+            return self.host_mod[self.host_mod.get_global_vars()[0]]
+        elif "main" in self.host_mod:
+            return self.host_mod["main"]
+        else:
+            for _, function in self.host_mod.functions.items():
+                attr = function.attrs
+                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
+                    return function
+        raise ValueError("Cannot find primary function in the module.")
+


class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
    """
@@ -636,7 +698,6 @@ def create_dispatch_func(self, code, function_informations):
            function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})

        function_args.append(self.get_stream_type())
-
        # Format the function arguments for declaration
        def_args = ", ".join([f"{arg['name']}" for arg in function_args])

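For context, the `update_lib_code` hunk above walks the host-side TIR with `post_order_visit` to recover the exact parameter list the host passes to each packed kernel call. Below is a minimal sketch of that technique, assuming a lowered host `PrimFunc` and the standard `tir.tvm_call_packed` convention; the helper name `find_packed_call_params` is illustrative and not part of this PR:

```python
from typing import List, Optional

import tvm
from tvm.tir.stmt_functor import post_order_visit


def find_packed_call_params(host_func: tvm.tir.PrimFunc, kernel_name: str,
                            param_cnt: int) -> Optional[List[tvm.tir.PrimExpr]]:
    """Return the first `param_cnt` arguments of the tvm_call_packed call to `kernel_name`."""
    packed_op = tvm.ir.Op.get("tir.tvm_call_packed")
    found: Optional[List[tvm.tir.PrimExpr]] = None

    def visitor(node):
        nonlocal found
        if isinstance(node, tvm.tir.Call) and node.op.same_as(packed_op):
            args = node.args
            # args[0] is the callee symbol (a StringImm); the kernel arguments follow it.
            if (args and isinstance(args[0], tvm.tir.StringImm) and
                    args[0].value == kernel_name and len(args) >= 1 + param_cnt):
                found = list(args[1:1 + param_cnt])

    post_order_visit(host_func.body, visitor)
    return found
```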
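The `func_call_args` change records which kernel parameter (a TIR `Var`) each TMA descriptor argument corresponds to, so `generate_tma_descriptor_args` can key `self.tma_descriptor_args` by that `Var` instead of by a reconstructed `*_desc` name. A hedged sketch of that bookkeeping, with illustrative names (`map_desc_args_to_params`, `is_desc`) that are not the PR's API:

```python
from typing import Dict, List

import tvm


def map_desc_args_to_params(call_args: List[str], function_params: List[tvm.tir.Var],
                            is_desc: List[bool]) -> Dict[str, tvm.tir.Var]:
    """Associate each descriptor-typed call argument with its kernel parameter Var."""
    assert len(call_args) <= len(function_params), "more call arguments than kernel parameters"
    desc_name_var_map: Dict[str, tvm.tir.Var] = {}
    for i, (arg, param) in enumerate(zip(call_args, function_params)):
        if is_desc[i]:
            # The host-side handle name now resolves directly to the TIR Var used
            # when looking up the descriptor's creation arguments.
            desc_name_var_map[arg] = param
    return desc_name_var_map
```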