Commit 3cca936

refactor model runner
Signed-off-by: weiguihua2 <[email protected]>
1 parent ab1d21f commit 3cca936

File tree: 6 files changed, +42 −21 lines

tests/ut/attention/test_mla_v1.py (1 addition, 1 deletion)

@@ -11,7 +11,7 @@
                                   AscendMLAImpl, AscendMLAMetadata,
                                   AscendMLAMetadataBuilder,
                                   AscendMLAPrefillMetadata)
-from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata
+from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 
 
 class TestAscendMLABackend(TestBase):
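
Aside from the new dataclass added to vllm_ascend/torchair/utils.py, nearly every hunk in this commit is the same mechanical change: TorchairCommonAttentionMetadata moves from vllm_ascend.attention.utils to vllm_ascend.torchair.utils, and callers are updated. A minimal sketch of what downstream code has to change, with both paths taken directly from the diff:

    # Before this commit:
    # from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata

    # After this commit:
    from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata

The same relocation repeats in mla_v1.py, torchair_attention.py, torchair_model_runner.py, and mtp_proposer_v1.py below.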

vllm_ascend/attention/mla_v1.py (2 additions, 2 deletions)

@@ -19,13 +19,13 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         TorchairCommonAttentionMetadata,
                                          split_decodes_and_prefills)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
+                                        npu_stream_switch, npu_wait_tensor)
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
vllm_ascend/torchair/torchair_attention.py (8 additions, 14 deletions)

@@ -22,28 +22,19 @@
 import torch
 import torch.nn as nn
 import torch_npu
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
-from vllm.config import VllmConfig
-from vllm.utils import cdiv
-from vllm.v1.core.sched.output import SchedulerOutput
-
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                               AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.utils import cdiv
 
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                 AscendAttentionMetadataBuilder,
                                                 AscendAttentionState,
                                                 AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         TorchairCommonAttentionMetadata)
+from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
-from vllm_ascend.worker.npu_input_batch import InputBatch
 
 
 class AscendAttentionTorchairBackend(AscendAttentionBackend):

@@ -108,9 +99,12 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
 
     def __init__(self, runner):
         super().__init__(runner)
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           vllm_config.cache_config.block_size)
-        self.max_blocks = (self.model_config.max_model_len + vllm_config.cache_config.block_size - 1) // vllm_config.cache_config.block_size
+        self.max_num_blocks_per_req = cdiv(
+            self.model_config.max_model_len,
+            self.vllm_config.cache_config.block_size)
+        self.max_blocks = (self.model_config.max_model_len +
+                           self.vllm_config.cache_config.block_size -
+                           1) // self.vllm_config.cache_config.block_size
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
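
Two things happen in this file beyond the import move. The first hunk drops a duplicated import block left over from an earlier merge, and the second hunk fixes a bug: the old lines referenced a bare vllm_config, which is not defined in __init__'s scope, so the new code qualifies it as self.vllm_config. The rewritten builder also leaves two equivalent ceiling divisions side by side: vLLM's cdiv helper and the manual (a + b - 1) // b idiom produce the same block count. A standalone sketch with made-up config numbers (cdiv is reimplemented here so the snippet runs on its own):

    def cdiv(a: int, b: int) -> int:
        # Ceiling integer division, matching vLLM's cdiv helper.
        return -(a // -b)

    max_model_len = 4096  # hypothetical model_config.max_model_len
    block_size = 128      # hypothetical cache_config.block_size

    max_num_blocks_per_req = cdiv(max_model_len, block_size)
    max_blocks = (max_model_len + block_size - 1) // block_size

    # Both idioms round up to the same number of KV-cache blocks.
    assert max_num_blocks_per_req == max_blocks == 32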

vllm_ascend/torchair/torchair_model_runner.py (2 additions, 2 deletions)

@@ -25,9 +25,9 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 
-from vllm_ascend.attention.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
+from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
+                                        check_torchair_cache_exist,
                                         register_torchair_model,
                                         write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,

vllm_ascend/torchair/utils.py (27 additions, 0 deletions)

@@ -2,6 +2,7 @@
 import os
 import shutil
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 
 import torch
 
@@ -20,6 +21,32 @@
     'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
 
 
+@dataclass
+class TorchairCommonAttentionMetadata:
+    """
+    Per-batch attention metadata, shared across layers and backends.
+    AttentionMetadataBuilder instances use it to construct per-layer metadata.
+
+    For many of the tensors we keep both GPU and CPU versions.
+    """
+
+    num_reqs: int
+    """Number of requests"""
+
+    num_actual_tokens: int
+    """Total number of tokens in batch"""
+
+    decode_token_per_req: int
+
+    actual_seq_lengths_q: list[int]
+
+    attn_mask: torch.Tensor = None
+
+    spec_attn_mask: torch.Tensor = None
+
+    graph_pad_size: int = -1
+
+
 @contextmanager
 def _file_lock(file_descriptor, lock_type):
     fcntl.flock(file_descriptor, lock_type)
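
For readers tracking the new type, here is a minimal, hypothetical construction of the dataclass added above. The field values are invented for illustration (a pure-decode batch of four requests, one token each), and the optional fields are shown explicitly at their defaults:

    import torch

    from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata

    metadata = TorchairCommonAttentionMetadata(
        num_reqs=4,
        num_actual_tokens=4,
        decode_token_per_req=1,
        actual_seq_lengths_q=[1, 2, 3, 4],  # invented per-request query lengths
        attn_mask=None,        # defaults, spelled out for clarity
        spec_attn_mask=None,
        graph_pad_size=-1,
    )
    assert metadata.num_reqs == len(metadata.actual_seq_lengths_q)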

vllm_ascend/worker/mtp_proposer_v1.py (2 additions, 2 deletions)

@@ -16,9 +16,9 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         TorchairCommonAttentionMetadata)
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
+from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import ProfileExecuteDuration