
Commit 509511f

fix lint
Signed-off-by: wangxiyuan <[email protected]>
1 parent: 9127256

7 files changed: +160 -106 lines changed


vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 if HAS_TRITON:
     import vllm_ascend.patch.worker.patch_common.patch_triton
 
+# isort: off
 import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_attentionspec  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_attention_layer  # noqa
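Note on the added marker: "# isort: off" is a standard isort action comment; from that line until a matching "# isort: on", isort leaves import order untouched, so patch imports whose side effects must run in a fixed order are not re-sorted. A minimal illustration with placeholder modules (not the vllm_ascend patches):

# isort: off
# isort will not reorder anything below this marker, so an import with side
# effects (e.g. one that monkey-patches another module) can stay ahead of the
# imports that depend on it.
import zlib  # kept first on purpose in this hypothetical example
import base64
# isort: on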

vllm_ascend/patch/worker/patch_common/patch_attentionspec.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ def page_size_bytes(self) -> int:
 
 
 vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
-from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 
 @dataclass(frozen=True)
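For context, the deleted import sat in the middle of the module, after the monkey-patch assignment, which a flake8-style check typically reports as E402 (module-level import not at top of file). A rough sketch of the attribute-patching pattern this file relies on, with simplified names rather than the actual vllm_ascend code:

import vllm.v1.kv_cache_interface as kv_cache_interface


class AttentionSpec(kv_cache_interface.AttentionSpec):
    """Simplified stand-in for the patched spec; details omitted."""


# Rebind the attribute on the vllm module so later lookups such as
# kv_cache_interface.AttentionSpec resolve to the patched class, with no need
# to re-import it further down this file.
kv_cache_interface.AttentionSpec = AttentionSpec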

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 6 additions & 3 deletions
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 import torchair
+import vllm.envs as envs_vllm
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
 from vllm.config import (VllmConfig, get_layers_from_vllm_config,
@@ -26,7 +27,6 @@
                          TorchairCommonAttentionMetadata)
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                vllm_version_is)
-import vllm.envs as envs_vllm
 
 PADDING_SLOT_ID = -1
 
@@ -61,7 +61,8 @@ def __init__(
         self.torchair_compiled_models = {}  # type: ignore
         self.torchair_graph_enabled = get_ascend_config(
         ).torchair_graph_config.enabled
-        self.enable_shared_expert_dp = get_ascend_config().enable_shared_expert_dp
+        self.enable_shared_expert_dp = get_ascend_config(
+        ).enable_shared_expert_dp
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -81,7 +82,9 @@ def load_model(self, model) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            if self.torchair_graph_enabled or (self.enable_shared_expert_dp and self.vllm_config.model_config.use_mla):
+            if self.torchair_graph_enabled or (
+                    self.enable_shared_expert_dp
+                    and self.vllm_config.model_config.use_mla):
                 self.model = TorchairDeepSeekMTP(
                     vllm_config=self.vllm_config).to(target_device)
             else:
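The mtp_proposer.py hunks are pure style fixes: the "import vllm.envs as envs_vllm" line moves up into the sorted import block, and statements over the line-length limit are re-wrapped inside their own parentheses. A small, self-contained illustration of that wrapping pattern (get_config below is a hypothetical stand-in, not a vllm_ascend API):

def get_config():
    """Hypothetical stand-in for a config accessor like get_ascend_config()."""
    class _Cfg:
        enable_shared_expert_dp = True
    return _Cfg()


# The formatter's choice above: break inside the call's parentheses so a long
# attribute access fits the line limit without a backslash continuation.
enable_shared_expert_dp = get_config(
).enable_shared_expert_dp

# An equivalent alternative: bind the result first, then read the attribute.
cfg = get_config()
enable_shared_expert_dp = cfg.enable_shared_expert_dp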

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 13 additions & 11 deletions
@@ -528,17 +528,15 @@ def __init__(
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.o_proj",
-                return_bias=False
-            )
+                return_bias=False)
         else:
             self.o_proj = TorchairDeepseekV2RowParallelLinear(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.o_proj",
-                return_bias=False
-            )
+                return_bias=False)
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -738,10 +736,10 @@ def __init__(
             return_bias=False,
         )
         if (config.n_routed_experts is not None
-            and self.debug_layer_idx >= config.first_k_dense_replace
-            and self.debug_layer_idx % config.moe_layer_freq == 0
-            and (ascend_config.multistream_overlap_shared_expert
-            or self.enable_shared_expert_dp)):
+                and self.debug_layer_idx >= config.first_k_dense_replace
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.multistream_overlap_shared_expert
+                     or self.enable_shared_expert_dp)):
             self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
@@ -827,8 +825,10 @@ def forward(
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         forward_context = get_forward_context()
         if not self.torchair_graph_enabled:
-            if forward_context.attn_metadata is not None and isinstance(forward_context.attn_metadata, dict):
-                attn_metadata = next(iter(forward_context.attn_metadata.values()), None)
+            if forward_context.attn_metadata is not None and isinstance(
+                    forward_context.attn_metadata, dict):
+                attn_metadata = next(
+                    iter(forward_context.attn_metadata.values()), None)
             else:
                 attn_metadata = forward_context.attn_metadata
         if kv_cache is None:
@@ -843,7 +843,9 @@
         # need_gather_q_kv = True
         if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace:
             output_shape = hidden_states.shape
-        if self.enable_shared_expert_dp and (self.debug_layer_idx == self.first_k_dense_replace or self.debug_layer_idx ==self.layers):
+        if self.enable_shared_expert_dp and (
+                self.debug_layer_idx == self.first_k_dense_replace
+                or self.debug_layer_idx == self.layers):
             rows = num_tokens // self.tp_size
             if num_tokens % self.tp_size:
                 rows += 1
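The torchair_deepseek_v2.py hunks likewise only re-wrap code: trailing ")" lines are folded onto the last keyword argument, and long boolean conditions get a consistent hanging indent inside their parentheses. A short, self-contained sketch of that condition style (names are illustrative, not taken from the model code):

# Illustrative only: a multi-clause condition wrapped with a hanging indent,
# matching the style the re-indented hunks above settle on.
def should_replace_o_proj(layer_idx: int, first_dense: int, moe_freq: int,
                          overlap_shared_expert: bool,
                          shared_expert_dp: bool) -> bool:
    return (layer_idx >= first_dense
            and layer_idx % moe_freq == 0
            and (overlap_shared_expert or shared_expert_dp))


print(should_replace_o_proj(4, 3, 1, False, True))  # True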
