@@ -1164,6 +1164,7 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
11641164 offsets : List [List [int ]]) -> None :
11651165 torch .ops ._C_custom_ar .register_graph_buffers (fa , handles , offsets )
11661166
1167+
11671168def get_flash_mla_metadata (
11681169 cache_seqlens : torch .Tensor ,
11691170 num_heads_per_head_k : int ,
@@ -1179,7 +1180,9 @@ def get_flash_mla_metadata(
11791180 tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
11801181 num_splits: (batch_size + 1), dtype torch.int32.
11811182 """
1182- return torch .ops ._C .get_flash_mla_metadata (cache_seqlens , num_heads_per_head_k , num_heads_k )
1183+ return torch .ops ._C .get_flash_mla_metadata (cache_seqlens ,
1184+ num_heads_per_head_k ,
1185+ num_heads_k )
11831186
11841187
11851188def flash_mla_with_kvcache (
@@ -1210,7 +1213,7 @@ def flash_mla_with_kvcache(
12101213 softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
12111214 """
12121215 if softmax_scale is None :
1213- softmax_scale = q .shape [- 1 ] ** (- 0.5 )
1216+ softmax_scale = q .shape [- 1 ]** (- 0.5 )
12141217 out , softmax_lse = torch .ops ._C .flash_mla_fwd_kvcache (
12151218 q ,
12161219 k_cache ,
0 commit comments