Skip to content

Commit 3ae4a6e

Browse files
format
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent a6c36cc commit 3ae4a6e

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

vllm/_custom_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,7 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
11641164
offsets: List[List[int]]) -> None:
11651165
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
11661166

1167+
11671168
def get_flash_mla_metadata(
11681169
cache_seqlens: torch.Tensor,
11691170
num_heads_per_head_k: int,
@@ -1179,7 +1180,9 @@ def get_flash_mla_metadata(
11791180
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
11801181
num_splits: (batch_size + 1), dtype torch.int32.
11811182
"""
-    return torch.ops._C.get_flash_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
+    return torch.ops._C.get_flash_mla_metadata(cache_seqlens,
+                                               num_heads_per_head_k,
+                                               num_heads_k)
11831186

11841187

11851188
def flash_mla_with_kvcache(
@@ -1210,7 +1213,7 @@ def flash_mla_with_kvcache(
12101213
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
12111214
"""
12121215
if softmax_scale is None:
-        softmax_scale = q.shape[-1] ** (-0.5)
+        softmax_scale = q.shape[-1]**(-0.5)
12141217
out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache(
12151218
q,
12161219
k_cache,

0 commit comments

Comments (0)