def use_cubin_header(sm,
                     head_size,
                     dtype,
                     output_dtype=None,
                     enable_skip_softmax=False,
                     attention_mask_type=None):
    """Decide whether a kernel spec should go through the cubin-header path.

    Args:
        sm: target SM architecture number (e.g. 89, 90).
        head_size: attention head size of the kernel.
        dtype: kernel data-type string (checked for an 'e4m3' substring).
        output_dtype: optional output data-type string ('bf16', 'fp16', ...).
        enable_skip_softmax: True when the kernel variant skips softmax.
        attention_mask_type: optional AttentionMaskType enum value.

    Returns:
        True only for the (SM90, head_size 128) and (SM89, e4m3) combinations,
        and never for skip-softmax kernels, e4m3 kernels emitting bf16/fp16,
        or bidirectional sliding-window attention masks.
    """
    # Guard clauses: variants explicitly excluded from the cubin-header path.
    if enable_skip_softmax:
        return False
    if 'e4m3' in dtype and output_dtype in ('bf16', 'fp16'):
        return False
    if attention_mask_type == AttentionMaskType.BIDIRECTIONAL_SLIDING_WINDOW:
        return False
    # Remaining eligibility: SM90 with head size 128, or SM89 with e4m3.
    if sm == 90 and head_size == 128:
        return True
    return sm == 89 and 'e4m3' in dtype
33423345
33433346
@@ -3349,9 +3352,11 @@ def get_cubin_header(kernel_traits, specs_names):
33493352 cubin_lens_dict = {}
33503353 launchers_dict = {}
33513354 for kspec , fname , lname , kname in specs_names :
3355+ mask_type = AttentionMaskType .BIDIRECTIONAL_SLIDING_WINDOW \
3356+ if '_bidirectional_sliding_window' in kname else None
33523357 if generate_cu_trtllm and not use_cubin_header (
33533358 kspec .sm , kspec .head_size , kspec .dtype , kspec .output_dtype ,
3354- kspec .enable_skip_softmax ):
3359+ kspec .enable_skip_softmax , mask_type ):
33553360 continue
33563361 name = fname .replace ('.' , '_' )
33573362 data = 'extern unsigned char cubin_{name}_cubin[];' .format (name = name )
@@ -3487,7 +3492,8 @@ def get_cubin_header(kernel_traits, specs_names):
34873492 return_softmax_stats_flag = pythonBoolean2cpp [sm != '90' or (
34883493 sm == '90' and '_softmax' in kname )]
34893494
3490- enable_skip_softmax_flag = pythonBoolean2cpp ['_skipSoftmax' in kname ]
3495+ enable_skip_softmax = '_skipSoftmax' in kname
3496+ enable_skip_softmax_flag = pythonBoolean2cpp [enable_skip_softmax ]
34913497
34923498 # meta_unroll_step
34933499 meta_unroll_step = unroll_step if ('_nl' in kname
@@ -3516,7 +3522,8 @@ def get_cubin_header(kernel_traits, specs_names):
35163522 def get_lname_from_kname (kname : str ) -> str :
35173523 if use_cubin_header (int (sm ), int (head_size ), prec .lower (),
35183524 output_prec .lower (),
3519- enable_skip_softmax_flag ):
3525+ enable_skip_softmax ,
3526+ attention_mask_type ):
35203527 return 'nullptr'
35213528 lname = kname .replace ('_kernel' , '' )
35223529 mask_types = [
@@ -3537,9 +3544,9 @@ def get_lname_from_kname(kname: str) -> str:
35373544 {cubin_name}_len, \" {kname}\" , {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
35383545 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
35393546 {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
3540- ''' .format (** locals ()) if use_cubin_header (int ( sm ), int ( head_size ),
3541- prec .lower (), output_prec .lower (),
3542- enable_skip_softmax_flag ) else '''\
3547+ ''' .format (** locals ()) if use_cubin_header (
3548+ int ( sm ), int ( head_size ), prec .lower (), output_prec .lower (),
3549+ enable_skip_softmax , attention_mask_type ) else '''\
35433550 {{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
35443551 {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
35453552 0, \" {kname}\" , {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
0 commit comments