@@ -1055,7 +1055,6 @@ def _dual_chunk_flash_attn_prefill_func(
             v_states_intra,
             softmax_scale=softmax_scale,
             causal=True,
-            block_table=block_table,
             stage="intra",
             vertical_indices=vertical_buffer,
             slash_indices=slash_buffer,
@@ -1070,7 +1069,6 @@ def _dual_chunk_flash_attn_prefill_func(
             v_states_intra,
             softmax_scale=softmax_scale,
             causal=True,
-            block_table=block_table,
             stage="intra",
             vertical_indices=intra_vertical_indices,
             slash_indices=intra_slash_indices,
@@ -1085,7 +1083,6 @@ def _dual_chunk_flash_attn_prefill_func(
             v_states_succ,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="succ",
             vertical_indices=succ_vertical_buffer,
             slash_indices=succ_slash_buffer,
@@ -1100,7 +1097,6 @@ def _dual_chunk_flash_attn_prefill_func(
             v_states_succ,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="succ",
             vertical_indices=succ_vertical_indices,
             slash_indices=succ_slash_indices,
@@ -1115,7 +1111,6 @@ def _dual_chunk_flash_attn_prefill_func(
             v_states_inter,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="inter",
             vertical_indices=inter_vertical_buffer,
             slash_indices=inter_slash_buffer,
@@ -1130,7 +1125,6 @@ def _dual_chunk_flash_attn_prefill_func(
             v_states_inter,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="inter",
             vertical_indices=inter_vertical_indices,
             slash_indices=inter_slash_indices,
@@ -1151,7 +1145,6 @@ def _do_flash_attn(
         value_states: torch.Tensor,
         softmax_scale: float,
         causal: bool = True,
-        block_table: torch.Tensor = None,
         max_seqlen_k: Optional[int] = None,
         stage: str = "intra",
         vertical_indices: Optional[torch.Tensor] = None,
@@ -1230,7 +1223,6 @@ def _do_flash_attn(
                 device=query_states.device),
             max_seqlen_k=max_seqlen_k,
             causal=causal,
-            block_table=block_table.unsqueeze(0),
             return_softmax_lse=True,
         )
         softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
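
Sketch of the post-change helper, reconstructed only from the hunks above and not the full implementation: the signature and every call site drop `block_table`, so `_do_flash_attn` now receives only the already-gathered Q/K/V slices for each stage. The `slash_indices` parameter and the return annotation below are assumptions implied by the call sites; the body is elided.

    from typing import Optional

    import torch


    def _do_flash_attn(
            query_states: torch.Tensor,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            softmax_scale: float,
            causal: bool = True,
            max_seqlen_k: Optional[int] = None,
            stage: str = "intra",
            vertical_indices: Optional[torch.Tensor] = None,
            slash_indices: Optional[torch.Tensor] = None,  # assumed: implied by the call sites above
    ) -> torch.Tensor:
        # Body elided; see the last hunk for the varlen flash-attention call
        # it wraps, which no longer passes block_table.
        ...

A plausible reading of the change: because the K/V tensors for each stage are materialized as contiguous tensors before the call, the paged-KV `block_table` indirection was unnecessary here.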