
Commit bfd980b

fix cuda launch error on devices < sm89
1 parent d097a13 commit bfd980b

1 file changed: +63 −50 lines changed
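
Judging from the new dispatch, the launch error comes from the FP8 path: the *_sv_f8_* kernels rely on FP8 (float8_e4m3fn) tensor-core support, which starts at compute capability 8.9 (Ada/Hopper) and is absent on the Ampere parts sm80/sm86/sm87. The commit routes those architectures to the int8-QK / fp16-PV kernel and moves the FP8 quantization of V into the else branch, so it only runs on devices that can use it. Below is a minimal sketch of the capability check, assuming arch is derived from torch.cuda.get_device_capability; the helper names are illustrative and not part of this repo:

import torch

def _sm_arch(device="cuda") -> str:
    # Illustrative helper: map the device's compute capability to the
    # "smXY" strings that the dispatch in core.py compares against.
    major, minor = torch.cuda.get_device_capability(device)
    return f"sm{major}{minor}"

def _fp8_supported(device="cuda") -> bool:
    # FP8 (float8_e4m3fn) matmul kernels require compute capability >= 8.9,
    # so Ampere devices (sm80/sm86/sm87) must take the fp16 PV path.
    return torch.cuda.get_device_capability(device) >= (8, 9)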

spas_sage_attn/core.py

Lines changed: 63 additions & 50 deletions
@@ -67,26 +67,29 @@ def spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_caus
     assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
 
     pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device)
-
-    ## quant v
-    b, h_kv, kv_len, head_dim = v.shape
-    padded_len = (kv_len + 127) // 128 * 128
-    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
-    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
-    v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
-    v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
-    #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
-    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
-
-    _is_causal = 1 if is_causal else 0
     o = torch.empty_like(q)
-
-    if arch == "sm90":
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
-    elif SAGE2PP_ENABLED:
-        qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+
+    if arch in ("sm80", "sm86", "sm87"):
+        qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
+            q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0
+        )
     else:
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        ## quant v
+        b, h_kv, kv_len, head_dim = v.shape
+        padded_len = (kv_len + 127) // 128 * 128
+        v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
+        fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
+        v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
+        v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+        #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
+        fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
+
+        if arch == "sm90":
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        elif SAGE2PP_ENABLED:
+            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        else:
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
 
     if tensor_layout == 'NHD':
         o = rearrange(o, '... H L D -> ... L H D')
@@ -130,26 +133,29 @@ def spas_sage2_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is
     assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
 
     pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device)
-
-    ## quant v
-    b, h_kv, kv_len, head_dim = v.shape
-    padded_len = (kv_len + 127) // 128 * 128
-    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
-    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
-    v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
-    v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
-    #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
-    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
-
-    _is_causal = 1 if is_causal else 0
     o = torch.empty_like(q)
-
-    if arch == "sm90":
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
-    elif SAGE2PP_ENABLED:
-        qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+
+    if arch in ("sm80", "sm86", "sm87"):
+        qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
+            q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0
+        )
     else:
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        ## quant v
+        b, h_kv, kv_len, head_dim = v.shape
+        padded_len = (kv_len + 127) // 128 * 128
+        v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
+        fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
+        v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
+        v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+        #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
+        fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
+
+        if arch == "sm90":
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        elif SAGE2PP_ENABLED:
+            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        else:
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
 
     if tensor_layout == 'NHD':
         o = rearrange(o, '... H L D -> ... L H D')
@@ -194,22 +200,29 @@ def block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, dropout_p=0.0, scale=Non
     assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
 
     pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device)
-
-    ## quant v
-    b, h_kv, kv_len, head_dim = v.shape
-    padded_len = (kv_len + 127) // 128 * 128
-    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
-    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
-    v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
-    v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
-    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
-
     o = torch.empty_like(q)
-
-    if arch == "sm90":
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+
+    if arch in ("sm80", "sm86", "sm87"):
+        qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
+            q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0
+        )
     else:
-        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        ## quant v
+        b, h_kv, kv_len, head_dim = v.shape
+        padded_len = (kv_len + 127) // 128 * 128
+        v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
+        fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
+        v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
+        v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+        #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1)
+        fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
+
+        if arch == "sm90":
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        elif SAGE2PP_ENABLED:
+            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
+        else:
+            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0)
 
     if tensor_layout == 'NHD':
         o = rearrange(o, '... H L D -> ... L H D')
