@@ -67,26 +67,29 @@ def spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_caus
     assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
 
     pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device)
-
-    ## quant v
-    b, h_kv, kv_len, head_dim = v.shape
-    padded_len = (kv_len + 127) // 128 * 128
-    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
-    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
-    v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
-    v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
-    #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
-    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
-
-    _is_causal = 1 if is_causal else 0
     o = torch.empty_like(q)
-
-    if arch == "sm90":
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
-    elif SAGE2PP_ENABLED:
-        qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+
+    if arch in ("sm80", "sm86", "sm87"):
+        qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
+            q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0
+        )
     else:
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        ## quant v
+        b, h_kv, kv_len, head_dim = v.shape
+        padded_len = (kv_len + 127) // 128 * 128
+        v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
+        fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
+        v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
+        v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+        #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
+        fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
+
+        if arch == "sm90":
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        elif SAGE2PP_ENABLED:
+            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        else:
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
 
     if tensor_layout == 'NHD':
         o = rearrange(o, '... H L D -> ... L H D')
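For readers skimming the hunk: on sm80/sm86/sm87, which lack FP8 tensor cores, V now stays in its original fp16/bf16 dtype and the f16-accumulation sparse kernel is called directly, while the FP8 quantization of V (transpose, pad to a multiple of 128, per-channel scale) only runs for the other architectures. A minimal sketch of the resulting dispatch follows; the callable parameters are hypothetical stand-ins for the qattn kernel calls above, not part of the repository.

from typing import Callable

def dispatch_block_sparse_kernel(
    arch: str,
    sage2pp_enabled: bool,
    run_f16_kernel: Callable[[], None],       # stand-in for qk_int8_sv_f16_accum_f16_* (V stays fp16, no v_scale)
    run_f8_sm90_kernel: Callable[[], None],   # stand-in for the *_sm90 FP8 kernel
    run_f8_sage2pp_kernel: Callable[[], None],
    run_f8_default_kernel: Callable[[], None],
) -> None:
    # Mirrors the branch structure introduced by this diff.
    if arch in ("sm80", "sm86", "sm87"):
        # No FP8 tensor cores on these parts: skip V quantization entirely.
        run_f16_kernel()
    else:
        # V has already been transposed, padded, and quantized to float8_e4m3fn
        # with a per-(batch, head, channel) v_scale before reaching this point.
        if arch == "sm90":
            run_f8_sm90_kernel()
        elif sage2pp_enabled:
            run_f8_sage2pp_kernel()
        else:
            run_f8_default_kernel()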
@@ -130,26 +133,29 @@ def spas_sage2_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is
     assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
 
     pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device)
-
-    ## quant v
-    b, h_kv, kv_len, head_dim = v.shape
-    padded_len = (kv_len + 127) // 128 * 128
-    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
-    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
-    v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
-    v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
-    #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
-    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
-
-    _is_causal = 1 if is_causal else 0
     o = torch.empty_like(q)
-
-    if arch == "sm90":
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
-    elif SAGE2PP_ENABLED:
-        qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+
+    if arch in ("sm80", "sm86", "sm87"):
+        qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
+            q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0
+        )
     else:
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        ## quant v
+        b, h_kv, kv_len, head_dim = v.shape
+        padded_len = (kv_len + 127) // 128 * 128
+        v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
+        fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
+        v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
+        v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+        #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
+        fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
+
+        if arch == "sm90":
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        elif SAGE2PP_ENABLED:
+            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        else:
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
 
     if tensor_layout == 'NHD':
         o = rearrange(o, '... H L D -> ... L H D')
@@ -194,22 +200,29 @@ def block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, dropout_p=0.0, scale=Non
     assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
 
     pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device)
-
-    ## quant v
-    b, h_kv, kv_len, head_dim = v.shape
-    padded_len = (kv_len + 127) // 128 * 128
-    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
-    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
-    v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
-    v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
-    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
-
     o = torch.empty_like(q)
-
-    if arch == "sm90":
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+
+    if arch in ("sm80", "sm86", "sm87"):
+        qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
+            q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0
+        )
     else:
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        ## quant v
+        b, h_kv, kv_len, head_dim = v.shape
+        padded_len = (kv_len + 127) // 128 * 128
+        v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
+        fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
+        v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
+        v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+        #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
+        fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
+
+        if arch == "sm90":
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        elif SAGE2PP_ENABLED:
+            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        else:
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
 
     if tensor_layout == 'NHD':
         o = rearrange(o, '... H L D -> ... L H D')
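The same ## quant v prologue now sits behind the non-sm8x branch in all three functions. As a rough shape-level reference for what it allocates, a pure-PyTorch sketch is given below; the exact permutation performed by fused.transpose_pad_permute_cuda and the meaning of the 2.25 constant passed to fused.scale_fuse_quant_cuda are assumptions, not taken from the kernel sources.

import torch

def quantize_v_fp8_reference(v: torch.Tensor, fp8_max: float = 2.25):
    # Shape-level sketch only. v: (b, h_kv, kv_len, head_dim), fp16/bf16.
    # Returns buffers with the shapes the diff allocates:
    #   v_fp8:   (b, h_kv, head_dim, padded_len) in float8_e4m3fn
    #   v_scale: (b, h_kv, head_dim) in float32
    # The fused CUDA kernels presumably also permute the padded axis into the
    # layout the attention kernel expects; that detail is omitted here.
    b, h_kv, kv_len, head_dim = v.shape
    padded_len = (kv_len + 127) // 128 * 128                # pad kv_len up to a multiple of 128
    v_t = v.transpose(-1, -2)                               # (b, h_kv, head_dim, kv_len)
    v_padded = torch.nn.functional.pad(v_t, (0, padded_len - kv_len))
    # Per-(batch, head, channel) scale so the values fit the range targeted by the kernel.
    v_scale = v_padded.abs().amax(dim=-1).float().clamp(min=1e-8) / fp8_max
    v_fp8 = (v_padded.float() / v_scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return v_fp8, v_scale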