diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index e8fe7ab6bf..ab593414ef 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -9,6 +9,7 @@ AscendAttentionState, AscendMetadata, CommonAttentionState) +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata class TestAscendAttentionBackend(TestBase): @@ -67,8 +68,12 @@ def test_copy_blocks(self): class TestAscendAttentionMetadataBuilder(TestBase): def setUp(self): - self.mock_runner = MagicMock() - self.builder = AscendAttentionMetadataBuilder(self.mock_runner) + self.mock_vllm_config = MagicMock() + self.mock_vllm_config.model_config.max_model_len = 640 + self.mock_vllm_config.cache_config.block_size = 64 + self.mock_device = 'cpu:0' + self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, + self.mock_device) def test_reorder_batch(self): mock_input_batch = MagicMock() @@ -86,31 +91,28 @@ def test_reorder_batch(self): def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, mock_npu_format_cast, mock_ascend_metadata): - num_reqs = 2 - num_actual_tokens = 10 - max_query_len = 5 - - self.mock_runner.input_batch.block_table = [MagicMock()] - self.mock_runner.input_batch.block_table[ - 0].get_device_tensor.return_value = torch.zeros((10, 10)) - self.mock_runner.max_num_blocks_per_req = 10 - self.mock_runner.query_lens = torch.tensor([3, 4]) - self.mock_runner.seq_lens_cpu = torch.tensor([5, 6]) - self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) - self.mock_runner.device = 'cpu:0' - self.mock_runner.attn_mask = torch.ones((10, 10)) - self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache - self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7]) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 3, 7]), + query_start_loc_cpu=torch.tensor([0, 3, 7]), + seq_lens_cpu=torch.tensor([5, 6]), + num_reqs=2, + num_actual_tokens=10, + max_query_len=5, + decode_token_per_req=torch.tensor([1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1]), + positions=torch.tensor([10, 10]), + attn_mask=torch.ones((10, 10)), + spec_attn_mask=None, + attn_state=AscendAttentionState.PrefillNoCache) mock_nz_tensor = MagicMock() + mock_model = MagicMock() mock_nd_to_nz_2d.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build( - num_reqs, - num_actual_tokens, - max_query_len, - ) + self.builder.build(common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('torch_npu.npu_format_cast') @@ -120,51 +122,53 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, def test_build_chunked_prefill(self, mock_ascend_attention_state, mock_is_310p, mock_nd_to_nz_spec, mock_npu_format_cast, mock_ascend_metadata): - num_reqs = 3 - num_actual_tokens = 15 - max_query_len = 6 - - self.mock_runner.input_batch.block_table = [MagicMock()] - self.mock_runner.input_batch.block_table[ - 0].get_device_tensor.return_value = torch.zeros((10, 10)) - self.mock_runner.max_num_blocks_per_req = 10 - self.mock_runner.query_lens = torch.tensor([2, 3, 4]) - self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) - self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) - self.mock_runner.device = 'cpu:0' - self.mock_runner.attn_mask = torch.ones((15, 15)) - self.mock_runner.attn_state = 
AscendAttentionState.ChunkedPrefill - self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 2, 5, 9]), + query_start_loc_cpu=torch.tensor([0, 2, 5, 9]), + seq_lens_cpu=torch.tensor([4, 5, 6]), + num_reqs=3, + num_actual_tokens=15, + max_query_len=6, + decode_token_per_req=torch.tensor([1, 1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1, 2]), + positions=torch.tensor([10, 10]), + attn_mask=torch.ones((15, 15)), + spec_attn_mask=None, + attn_state=AscendAttentionState.ChunkedPrefill) mock_ascend_attention_state = MagicMock() mock_ascend_attention_state.PrefillNoCache = 0 mock_nz_tensor = MagicMock() + mock_model = MagicMock() mock_nd_to_nz_spec.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build(num_reqs, num_actual_tokens, max_query_len) + self.builder.build(common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata): - num_reqs = 3 - num_actual_tokens = 15 - max_query_len = 6 - - self.mock_runner.input_batch.block_table = [MagicMock()] - self.mock_runner.input_batch.block_table[ - 0].get_device_tensor.return_value = torch.zeros((10, 10)) - self.mock_runner.max_num_blocks_per_req = 10 - self.mock_runner.query_lens = torch.tensor([2, 3, 4]) - self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) - self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) - self.mock_runner.device = 'cpu:0' - self.mock_runner.attn_mask = torch.ones((15, 15)) - self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill - self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) - - self.builder.build(num_reqs, num_actual_tokens, max_query_len) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor([0, 2, 5, 9]), + query_start_loc_cpu=torch.tensor([0, 2, 5, 9]), + seq_lens_cpu=torch.tensor([4, 5, 6]), + num_reqs=3, + num_actual_tokens=15, + max_query_len=6, + decode_token_per_req=torch.tensor([1, 1, 1]), + block_table_tensor=torch.zeros((10, 10)), + slot_mapping_cpu=torch.tensor(range(20)), + actual_seq_lengths_q=torch.tensor([0, 1, 2]), + positions=torch.tensor([10, 10]), + attn_mask=torch.ones((15, 15)), + spec_attn_mask=None, + attn_state=AscendAttentionState.ChunkedPrefill) + mock_model = MagicMock() + + self.builder.build(common_attn_metadata, mock_model) class TestAscendAttentionBackendImpl(TestBase): diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 497b7b53ab..be2a7d897e 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -1,6 +1,5 @@ from unittest.mock import MagicMock, patch -import numpy as np import torch from vllm.distributed.parallel_state import GroupCoordinator from vllm.model_executor.layers.linear import LinearBase @@ -12,6 +11,7 @@ AscendMLAImpl, AscendMLAMetadata, AscendMLAMetadataBuilder, AscendMLAPrefillMetadata) +from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata class TestAscendMLABackend(TestBase): @@ -178,40 +178,41 @@ def test_ascend_mla_metadata_default(self): class TestAscendMLAMetadataBuilder(TestBase): def test_ascend_mla_metadata_builder_default(self): - runner = MagicMock() - runner.scheduler_config = MagicMock() - 
runner.model_config = MagicMock() - runner.scheduler_config.max_num_seqs = 4 - runner.model_config.max_model_len = 1024 - runner.model_config.get_head_size.return_value = 64 - runner.model_config.dtype = torch.float16 - runner.chunked_prefill_enabled = False - runner.device = "cpu" - runner.block_size = 16 - runner.decode_token_per_req = 1 + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' ascend_config = MagicMock() ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(runner) + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) - self.assertEqual(builder.runner, runner) - self.assertEqual(builder.block_size, runner.block_size) - self.assertEqual(builder.chunked_prefill_enabled, - runner.chunked_prefill_enabled) + self.assertEqual(builder.block_size, + mock_vllm_config.cache_config.block_size) + self.assertEqual( + builder.chunked_prefill_enabled, + mock_vllm_config.scheduler_config.chunked_prefill_enabled) self.assertEqual(builder.torchair_graph_enabled, True) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def test_reorder_batch_with_torchair_graph(self, ascend_config): - runner = MagicMock() - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True - builder = AscendMLAMetadataBuilder(runner) + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) input_batch = MagicMock() input_batch.req_ids = [0, 1, 2, 3] @@ -230,22 +231,23 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config): modified = builder.reorder_batch(input_batch, scheduler_output) self.assertFalse(modified) - self.assertEqual(builder._num_decodes, 4) - self.assertEqual(builder._num_prefills, 0) - self.assertEqual(builder._num_decode_tokens, 7) - self.assertEqual(builder._num_prefill_tokens, 0) input_batch.swap_states.assert_not_called() def test_reorder_batch_without_torchair_graph(self): ascend_config = MagicMock() - runner = MagicMock() - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = False + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(runner) + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) input_batch = MagicMock() input_batch.req_ids = [0, 1, 2, 3] @@ -264,10 +266,6 @@ def 
test_reorder_batch_without_torchair_graph(self): modified = builder.reorder_batch(input_batch, scheduler_output) self.assertTrue(modified) - self.assertEqual(builder._num_decodes, 2) - self.assertEqual(builder._num_prefills, 2) - self.assertEqual(builder._num_decode_tokens, 2) - self.assertEqual(builder._num_prefill_tokens, 5) input_batch.swap_states.assert_called_once_with(1, 2) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") @@ -275,11 +273,13 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32) - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 - builder = AscendMLAMetadataBuilder(runner=runner) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) result = builder._get_graph_runner_block_tables(3, block_tables) @@ -292,11 +292,13 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32) - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 - builder = AscendMLAMetadataBuilder(runner=runner) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 64 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) result = builder._get_graph_runner_block_tables(3, block_tables) @@ -310,11 +312,13 @@ def test_get_graph_runner_block_tables_from_numpy(self, ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32) - runner.chunked_prefill_enabled = False - runner.decode_token_per_req = 1 - builder = AscendMLAMetadataBuilder(runner=runner) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -329,38 +333,45 @@ def test_build_dummy(self, mock_ascend_config): ascend_config = MagicMock() mock_ascend_config.return_value = ascend_config ascend_config.torchair_graph_config.enabled = False - runner = MagicMock() - runner.model_config = MagicMock() - runner.device = "cpu" - runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32) - runner.model_config.get_head_size.return_value = 64 - runner.chunked_prefill_enabled = False - runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool) - runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool) - runner.dtype = torch.float16 - 
runner.decode_token_per_req = 1 - - builder = AscendMLAMetadataBuilder(runner=runner, + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_vllm_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_device = 'cpu' + + builder = AscendMLAMetadataBuilder(mock_vllm_config, + mock_device, metadata_cls=AscendMLAMetadata) builder.rope_dim = 64 with patch.object(builder, "_get_graph_runner_block_tables", side_effect=lambda x, y: y): - metadata = builder.build_torchair_graph_dummy(3, 3) + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=3, + num_actual_tokens=3, + decode_token_per_req=1, + actual_seq_lengths_q=[0, 1, 2], + attn_mask=torch.zeros((1, 1), dtype=torch.bool), + spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool), + ) + metadata = builder.build_torchair_graph_dummy(common_attn_metadata) sin_golden = torch.ones(3, 1, 1, 64, - dtype=runner.dtype, - device=runner.device) + dtype=torch.float16, + device=mock_device) cos_golden = torch.ones(3, 1, 1, 64, - dtype=runner.dtype, - device=runner.device) + dtype=torch.float16, + device=mock_device) self.assertIsInstance(metadata, AscendMLAMetadata) self.assertEqual(metadata.num_input_tokens, 3) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 81a337563e..87d698509e 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -20,14 +20,17 @@ from typing import List, Optional, Tuple, Type import torch +import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import direct_register_custom_op +from vllm.utils import cdiv, direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) @@ -157,35 +160,49 @@ class AscendMetadata: class AscendAttentionMetadataBuilder: - def __init__(self, runner): - self.runner = runner + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, + vllm_config.cache_config.block_size) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False - def build(self, - num_reqs, - num_actual_tokens, - max_query_len, - enable_dbo_across_dp: bool = False, - is_only_prefill: bool = False, - *args, - **kwargs): - - block_table = self.runner.input_batch.block_table[0].get_device_tensor( - ) - block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + def build( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] + + block_table = 
common_attn_metadata.block_table_tensor + block_table[:num_reqs, :self.max_num_blocks_per_req] = ( block_table[:num_reqs]) - query_lens = self.runner.query_lens - seq_lens = self.runner.seq_lens_cpu[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True) - attn_mask = self.runner.attn_mask - attn_state = self.runner.attn_state - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + self.device, + non_blocking= + True) + attn_mask = common_attn_metadata.attn_mask + attn_state = common_attn_metadata.attn_state + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] + query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) if is_310p(): @@ -204,12 +221,12 @@ def build(self, query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, - max_query_len=max_query_len, + max_query_len=common_attn_metadata.max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=enable_dbo_across_dp, - is_only_prefill=is_only_prefill) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + is_only_prefill=common_attn_metadata.is_only_prefill) return attn_metadata diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a52d117b43..72a2d4f783 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -3,12 +3,13 @@ import numpy as np import torch +import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import get_current_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -17,11 +18,14 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, + npu_stream_switch, npu_wait_tensor) from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch @@ -172,20 +176,24 @@ class AscendMLAMetadataBuilder: # _attn_mask_builder = None def __init__(self, - runner, + vllm_config: VllmConfig, + device: torch.device, metadata_cls: Optional[AscendMLAMetadata] = None): self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \ if metadata_cls is not None else AscendMLAMetadata # type: ignore - self.runner = runner - scheduler_config = runner.scheduler_config - model_config = runner.model_config - 
self.block_size = runner.block_size - self.chunked_prefill_enabled = runner.chunked_prefill_enabled + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( # Max sure there is enough for 8 full length request or at least # 4 pages of cache per request - max(8 * model_config.max_model_len, + max(8 * self.model_config.max_model_len, 4 * scheduler_config.max_num_seqs * self.block_size), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, @@ -200,13 +208,13 @@ def __init__(self, scheduler_config.max_num_seqs * self.block_size self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, - model_config.get_head_size()), - dtype=model_config.dtype, - device=runner.device, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, ) ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.cos_cache = None self.sin_cache = None @@ -220,8 +228,6 @@ def reorder_batch(self, input_batch: "InputBatch", # better naming here) decodes = [] prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] @@ -231,18 +237,14 @@ def reorder_batch(self, input_batch: "InputBatch", if self.torchair_graph_enabled: if num_tokens - num_spec_tokens == 1: decodes.append(i) - num_decode_tokens += num_tokens else: prefills.append(i) - num_prefill_tokens += num_tokens # For eager mode we treat spec decoding as chunked prefill. 
else: if num_tokens == 1: decodes.append(i) - num_decode_tokens += num_tokens else: prefills.append(i) - num_prefill_tokens += num_tokens # We hope that this is fairly minimal since decodes # should be around for a number of iterations so hopefully they are @@ -273,26 +275,15 @@ def reorder_batch(self, input_batch: "InputBatch", # Save for next `build` call # TODO(lucas): this is a bit of a hack, we should probably have a # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - return modified_batch def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}" - - if isinstance(self.runner.graph_block_tables, np.ndarray): - graph_block_tables = torch.zeros((max_batch_size, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - else: - graph_block_tables = self.runner.graph_block_tables.to( - device=block_tables.device, dtype=block_tables.dtype) + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) num_blocks = block_tables.size(1) if num_blocks <= max_blocks: @@ -304,18 +295,20 @@ def _get_graph_runner_block_tables( max_blocks] = block_tables[:num_seqs, : max_blocks] - return graph_block_tables[:num_seqs, :max_blocks] + return graph_block_tables[:, :max_blocks] def build_torchair_graph_dummy( - self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata: - device = self.runner.device - _, max_blocks = self.runner.graph_block_tables.shape - block_table = torch.zeros((num_reqs, max_blocks), + self, + common_attn_metadata: TorchairCommonAttentionMetadata, + ) -> AscendMLAMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), dtype=torch.int32, device=device) block_table = self._get_graph_runner_block_tables( num_reqs, block_table) - num_tokens = num_reqs * self.runner.decode_token_per_req + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) seq_lens_list = [0] * num_reqs input_positions = torch.zeros(num_tokens, @@ -333,16 +326,16 @@ def build_torchair_graph_dummy( 1, 1, self.rope_dim, - dtype=self.runner.dtype, + dtype=self.model_config.dtype, device=device) cos = torch.ones(num_tokens, 1, 1, self.rope_dim, - dtype=self.runner.dtype, + dtype=self.model_config.dtype, device=device) - if self.runner.speculative_config is not None and\ - self.runner.speculative_config.method == 'deepseek_mtp': + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': attn_state = AscendAttentionState.SpecDecoding num_decode_tokens = 2 else: @@ -354,20 +347,21 @@ def build_torchair_graph_dummy( seq_lens=seq_lens, seq_lens_list=seq_lens_list, max_seq_lens=1, - attn_mask=self.runner.spec_attn_mask, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs], + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=common_attn_metadata. 
+ actual_seq_lengths_q[:num_reqs], sin=sin, cos=cos, ) return self.metadata_cls( # type: ignore - num_input_tokens=num_actual_tokens, - num_actual_tokens=num_actual_tokens, + num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, slot_mapping=slot_mapping, - head_dim=self.runner.model_config.get_head_size(), + head_dim=self.model_config.get_head_size(), num_decodes=1, num_decode_tokens=num_decode_tokens, num_prefills=0, - attn_mask=self.runner.attn_mask, + attn_mask=common_attn_metadata.attn_mask, attn_state=attn_state, prefill=None, decode=decode_metadata, @@ -378,58 +372,68 @@ def build_torchair_graph_dummy( def build( self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, - graph_pad_size: int = -1, - query_start_loc: torch.Tensor = None, - enable_dbo_across_dp: bool = False, - *args, - **kwargs, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, ) -> AscendMLAMetadata: - assert self._num_decodes + self._num_prefills == num_reqs + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. - device = self.runner.device - - block_table = (self.runner.input_batch.block_table[0]. 
- get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True) - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() - - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[: - num_reqs] - seq_lens = seq_lens_cpu - max_query_len = query_lens.max().item() - max_seq_lens = seq_lens.max().item() + device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + device, + non_blocking= + True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + if self.cos_cache is None: - self.cos_cache = self.runner.get_model( - ).model.layers[0].self_attn.rotary_emb.cos_cached - self.sin_cache = self.runner.get_model( - ).model.layers[0].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.runner.dtype: # type: ignore + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore self.cos_cache = self.cos_cache.to( # type: ignore - self.runner.dtype) # type: ignore + self.model_config.dtype) # type: ignore self.sin_cache = self.sin_cache.to( # type: ignore - self.runner.dtype) # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) prefill_metadata = None chunked_context_metadata = None - if self._num_prefills > 0: - reqs_start = self._num_decodes # prefill_start - tokens_start = self._num_decode_tokens + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens max_query_len = query_lens[tokens_start:].max().item() max_seq_lens = seq_lens[tokens_start:].max().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] - context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[ - reqs_start:num_reqs] + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() if self.chunked_prefill_enabled and max_context_len_cpu > 0: @@ -441,12 +445,12 @@ def build( assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) cu_seq_lens_cpu = torch.zeros(num_chunks, - self._num_prefills + 1, + num_prefills + 1, dtype=torch.int32, pin_memory=True) torch.cumsum(chunk_seq_lens, @@ -470,7 +474,7 @@ def build( prefill_input_positions].unsqueeze( # type: ignore 1).unsqueeze(2) prefill_metadata = AscendMLAPrefillMetadata( - attn_mask=self.runner.attn_mask, + attn_mask=common_attn_metadata.attn_mask, query_lens=query_lens[tokens_start:], seq_lens=seq_lens, context_lens=seq_lens[tokens_start:], @@ -485,14 +489,15 @@ def 
build( ) decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size != -1 - if self._num_decodes > 0: + if num_decodes > 0: actual_seq_lengths_q = query_start_loc[1:].tolist() - max_seq_lens = seq_lens[:self._num_decodes].max().item() - seq_lens = seq_lens[:self._num_decode_tokens] - input_positions = input_positions[:self._num_decode_tokens] - block_table = block_table[:self._num_decode_tokens, ...] - if use_torchair_graph and self.runner.attn_state in [ + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decode_tokens] + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decode_tokens, ...] + if use_torchair_graph and common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ]: @@ -500,10 +505,10 @@ def build( num_token_pad_size = 0 if graph_pad_size != 0: pad_value = 0 - num_token_pad_size = graph_pad_size - self._num_decode_tokens + num_token_pad_size = graph_pad_size - num_decode_tokens num_reqs_pad_size = ( - graph_pad_size // self.runner.decode_token_per_req - - num_reqs) + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) padded_seq_lens = seq_lens.tolist( ) + [pad_value] * num_reqs_pad_size else: @@ -531,14 +536,14 @@ def build( input_positions = torch.cat( [input_positions, position_padding]) actual_seq_lengths_q = query_start_loc[1:].tolist( - ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs + - num_reqs_pad_size] + ) + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] else: seq_lens_list = seq_lens.tolist() # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) batch_size = slot_mapping.size(0) if actual_seq_lengths_q[-1] != batch_size \ - and self.runner.attn_state == AscendAttentionState.SpecDecoding: + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: actual_seq_lengths_q[-1] = batch_size cos = self.cos_cache[input_positions].unsqueeze( # type: ignore @@ -552,7 +557,7 @@ def build( seq_lens=seq_lens, seq_lens_list=seq_lens_list, max_seq_lens=max_seq_lens, - attn_mask=self.runner.spec_attn_mask, + attn_mask=common_attn_metadata.spec_attn_mask, actual_seq_lengths_q=actual_seq_lengths_q, sin=sin, cos=cos) @@ -561,18 +566,18 @@ def build( num_actual_tokens=num_actual_tokens, query_lens=query_lens.tolist(), slot_mapping=slot_mapping, - head_dim=self.runner.model_config.get_head_size(), - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, - attn_mask=self.runner.attn_mask, - attn_state=self.runner.attn_state, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, prefill=prefill_metadata, decode=decode_metadata, query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - enable_dbo_across_dp=enable_dbo_across_dp, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py new file mode 100644 index 0000000000..42f9fa8550 --- /dev/null +++ b/vllm_ascend/attention/utils.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from typing import Any + +import torch + + +@dataclass +class AscendCommonAttentionMetadata: + """ + Per-batch attention 
metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. + """ + + query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor + """(batch_size + 1,), the start location of each request in query Tensor""" + + seq_lens_cpu: torch.Tensor + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" + + num_reqs: int + """Number of requests""" + num_actual_tokens: int + """Total number of tokens in batch""" + + max_query_len: int + """Max token number of request in batch""" + + decode_token_per_req: int + """decode token number per request""" + + block_table_tensor: torch.Tensor + + slot_mapping_cpu: torch.Tensor + + actual_seq_lengths_q: list[int] + + positions: torch.Tensor = None + + attn_mask: torch.Tensor = None + + spec_attn_mask: torch.Tensor = None + + attn_state: Any = None + + enable_dbo_across_dp: bool = False + + is_only_prefill: bool = False + + graph_pad_size: int = -1 + + +def split_decodes_and_prefills( + common_attn_metadata: AscendCommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: AscendCommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. + """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, num_tokens, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, num_tokens, 0 + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] >= decode_threshold) + assert torch.all(query_lens[:first_prefill] <= decode_threshold) + num_decodes = first_prefill + num_prefills = num_reqs - num_decodes + num_decode_tokens = query_start_loc[first_prefill].item() + num_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) diff --git a/vllm_ascend/lora/punica_wrapper/lora_ops.py b/vllm_ascend/lora/punica_wrapper/lora_ops.py index a8ff21d748..e8bf8ad971 100644 --- a/vllm_ascend/lora/punica_wrapper/lora_ops.py +++ b/vllm_ascend/lora/punica_wrapper/lora_ops.py @@ -52,14 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - return torch.ops._C.bgmv_expand( - inputs, - lora_b_weights, - lora_indices_tensor, - output_tensor, - slice_offset, - slice_size - ) + return torch.ops._C.bgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, output_tensor, + slice_offset, slice_size) def sgmv_shrink( @@ -74,8 +69,9 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor, - seq_len_tensor, output_tensor, scaling) + return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, + 
lora_indices_tensor, seq_len_tensor, + output_tensor, scaling) def sgmv_expand(inputs: torch.Tensor, @@ -111,12 +107,6 @@ def sgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand( - inputs, - lora_b_weights, - lora_indices_tensor, - seq_len_tensor, - output_tensor, - slice_offset, - slice_size - ) + return torch.ops._C.sgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, slice_offset, slice_size) diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index f292e61423..47c775887d 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -80,23 +80,18 @@ def get_masked_input_and_mask_meta(input: torch.Tensor, return masked_input, mask -def bgmv_expand_meta(x: torch.Tensor, - weight: torch.Tensor, - indices: torch.Tensor, - y: torch.Tensor, - slice_offset: int, - slice_size: int): + +def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, + indices: torch.Tensor, y: torch.Tensor, slice_offset: int, + slice_size: int): y_out = torch.empty_like(y) return y_out -def sgmv_expand_meta(x: torch.Tensor, - weight: torch.Tensor, - lora_indices: torch.Tensor, - seq_len: torch.Tensor, - y: torch.Tensor, - slice_offset: int, - slice_size: int): + +def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, + lora_indices: torch.Tensor, seq_len: torch.Tensor, + y: torch.Tensor, slice_offset: int, slice_size: int): y_out = torch.empty_like(y) return y_out diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index a3fda61036..a09f838381 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -20,15 +20,20 @@ import numpy as np import torch +import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, AscendAttentionMetadataBuilder, AscendAttentionState, AscendMetadata) +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d) @@ -88,25 +93,31 @@ class AscendTorchairMetadata(AscendMetadata): decode: Optional[AscendDecodeMetadata] = None + enable_dbo_across_dp: bool = False + class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): - def __init__(self, runner): - super().__init__(runner) + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(vllm_config, device) + self.max_num_blocks_per_req = cdiv( + self.model_config.max_model_len, + self.vllm_config.cache_config.block_size) + self.max_blocks = (self.model_config.max_model_len + + self.vllm_config.cache_config.block_size - + 1) // self.vllm_config.cache_config.block_size def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}" - - if isinstance(self.runner.graph_block_tables, np.ndarray): - graph_block_tables = 
torch.zeros((max_batch_size, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - else: - graph_block_tables = self.runner.graph_block_tables.to( - device=block_tables.device, dtype=block_tables.dtype) + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) num_blocks = block_tables.size(1) if num_blocks <= max_blocks: @@ -118,14 +129,14 @@ def _get_graph_runner_block_tables( max_blocks] = block_tables[:num_seqs, : max_blocks] - return graph_block_tables[:num_seqs, :max_blocks] + return graph_block_tables[:, :max_blocks] def build_torchair_graph_dummy( - self, num_reqs: int, - num_actual_tokens: int) -> AscendTorchairMetadata: - device = self.runner.device - _, max_blocks = self.runner.graph_block_tables.shape - block_table = torch.zeros((num_reqs, max_blocks), + self, common_attn_metadata: TorchairCommonAttentionMetadata + ) -> AscendTorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), dtype=torch.int32, device=device) block_table = self._get_graph_runner_block_tables( @@ -150,7 +161,7 @@ def build_torchair_graph_dummy( max_seq_lens=1) attn_metadata = AscendTorchairMetadata( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, block_tables=block_table, query_lens=0, query_start_loc=query_start_loc, @@ -160,52 +171,50 @@ def build_torchair_graph_dummy( decode=decode_metadata) return attn_metadata - def build(self, - num_reqs, - num_actual_tokens, - max_query_len, - enable_dbo_across_dp: bool = False, - is_only_prefill: bool = False, - *args, - **kwargs): - - if 'graph_pad_size' in kwargs: - graph_pad_size = kwargs['graph_pad_size'] - else: - graph_pad_size = -1 # default value - - device = self.runner.device - - block_table = self.runner.input_batch.block_table[0].get_device_tensor( - ) - block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + def build( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + + block_table = common_attn_metadata.block_table_tensor + block_table[:num_reqs, :self.max_num_blocks_per_req] = ( block_table[:num_reqs]) - query_lens = self.runner.query_lens - seq_lens = self.runner.seq_lens_cpu[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True) - attn_mask = self.runner.attn_mask + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping_cpu[: + num_actual_tokens].to( + self.device, + non_blocking= + True) + attn_mask = common_attn_metadata.attn_mask - attn_state = self.runner.attn_state + attn_state = common_attn_metadata.attn_state if is_310p() and attn_state == AscendAttentionState.PrefillNoCache: mask_nz = nd_to_nz_2d(attn_mask) attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] + query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + input_positions = 
common_attn_metadata.positions[: + num_actual_tokens].long( + ) decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size > -1 - if self.runner.attn_state in [ + if common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, ]: max_seq_lens = seq_lens.max().item() num_seqs = len(seq_lens) - if use_torchair_graph and self.runner.attn_state in [ + if use_torchair_graph and common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, ]: num_reqs_pad_size = 0 @@ -214,8 +223,8 @@ def build(self, pad_value = 0 num_token_pad_size = graph_pad_size - num_actual_tokens num_reqs_pad_size = ( - graph_pad_size // self.runner.decode_token_per_req - - num_reqs) + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) pad_value = 1 padded_seq_lens = seq_lens.tolist() + [pad_value ] * num_reqs_pad_size @@ -255,11 +264,11 @@ def build(self, query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, - max_query_len=max_query_len, + max_query_len=common_attn_metadata.max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=enable_dbo_across_dp) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) return attn_metadata diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index a07304b96a..62b1e1e6fe 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -26,7 +26,8 @@ from vllm.logger import logger from vllm_ascend.platform import NPUPlatform -from vllm_ascend.torchair.utils import (check_torchair_cache_exist, +from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, + check_torchair_cache_exist, register_torchair_model, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, @@ -71,8 +72,16 @@ def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): # NOTE: If torchair graph mode and not with_prefill, # we can't skip_attn, it will cause graph recompile. if not with_prefill: + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.actual_seq_lengths_q, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + decode_token_per_req=self.decode_token_per_req, + ) attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( - num_reqs=num_reqs, num_actual_tokens=1) + common_attn_metadata) else: attn_metadata = super()._build_attention_metadata( with_prefill, num_reqs, skip_attn) diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 0a94494cb2..cdc4ba3a15 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -2,6 +2,7 @@ import os import shutil from contextlib import contextmanager, nullcontext +from dataclasses import dataclass import torch @@ -20,6 +21,32 @@ 'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME)) +@dataclass +class TorchairCommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. 
+ """ + + num_reqs: int + """Number of requests""" + + num_actual_tokens: int + """Total number of tokens in batch""" + + decode_token_per_req: int + + actual_seq_lengths_q: list[int] + + attn_mask: torch.Tensor = None + + spec_attn_mask: torch.Tensor = None + + graph_pad_size: int = -1 + + @contextmanager def _file_lock(file_descriptor, lock_type): fcntl.flock(file_descriptor, lock_type) diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index 18fb9fda8d..895649327c 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -16,6 +16,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata PADDING_SLOT_ID = -1 @@ -125,12 +126,27 @@ def propose( query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:batch_size + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + max_query_len=max_query_len, num_reqs=batch_size, num_actual_tokens=num_tokens, - max_query_len=max_query_len, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping_cpu=target_slot_mapping, + positions=target_positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, ) + # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
+ attn_metadata = self.runner.attn_metadata_builder.build( + common_attn_metadata, self.runner.model) if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f4450d352f..e6ca63ad2b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -23,7 +23,6 @@ import os import time import types -import weakref from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast @@ -79,6 +78,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, DummyCommImpl, MoECommMethod) @@ -215,7 +215,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): use_mla=self.model_config.use_mla, ) self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) + vllm_config, device) self.attn_mask_builder = AttentionMaskBuilder( min(self.model_config.max_model_len, int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype) @@ -228,13 +228,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None self.actual_seq_lengths_q = [] - self.spec_token_num = 0 self.decode_token_per_req = 1 if self.speculative_config: self.use_spec_decode = True - self.spec_token_num = self.speculative_config.num_speculative_tokens - assert self.spec_token_num > 0 - self.decode_token_per_req = 1 + self.spec_token_num + spec_token_num = self.speculative_config.num_speculative_tokens + assert spec_token_num > 0 + self.decode_token_per_req = 1 + spec_token_num self.actual_seq_lengths_q = [ len for len in range(self.decode_token_per_req, self.max_num_tokens + @@ -798,11 +797,25 @@ def get_eagle_atten_dict( # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - attn_metadata_i = self.attn_metadata_builder.build( + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + block_table_tensor=self.input_batch.block_table[0]. 
+ get_device_tensor(), + slot_mapping_cpu=self.slot_mapping_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + decode_token_per_req=self.decode_token_per_req, ) + attn_metadata_i = self.attn_metadata_builder.build( + common_attn_metadata, self.get_model()) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1168,8 +1181,6 @@ def _process_reqs( attn_state=attn_state) self.attn_state = attn_state # type: ignore - extra_builder_kwargs = {} - self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc[:num_reqs + 1].copy_( @@ -1186,45 +1197,44 @@ def _process_reqs( ] is_only_prefill = bool(np.all(num_valid_tokens != 1)) - extra_builder_kwargs['is_only_prefill'] = is_only_prefill enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens) - enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), - attn_state, - total_num_scheduled_tokens) (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, enable_dbo) = self._get_forward_metadata_across_dp_and_pad( total_num_scheduled_tokens, with_prefill, enable_dbo) - extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp if self.torchair_graph_enabled and not with_prefill: self.graph_pad_size = padded_num_tokens_across_dp - extra_builder_kwargs[ - 'graph_pad_size'] = self.graph_pad_size # type: ignore else: self.graph_pad_size = -1 - + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + block_table_tensor=self.input_batch.block_table[0]. 
+ get_device_tensor(), + slot_mapping_cpu=self.slot_mapping_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=is_only_prefill, + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + ) + attn_metadata = self.attn_metadata_builder.build( + common_attn_metadata, self.model) if self.vllm_config.model_config.use_mla: - extra_builder_kwargs[ - "query_start_loc"] = self.query_start_loc[:num_reqs + 1] - attn_metadata = self.attn_metadata_builder.build( # type: ignore - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - **extra_builder_kwargs, - ) attn_metadata.num_input_tokens = num_input_tokens - else: - attn_metadata = self.attn_metadata_builder.build( # type: ignore - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - **extra_builder_kwargs, - ) # Prepare input_ids token_indices = (positions_np + diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index f4597de23f..949314303b 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -16,7 +16,9 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP +from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata from vllm_ascend.utils import ProfileExecuteDuration @@ -88,7 +90,7 @@ def prepare_inputs( # FIXME(woosuk): Avoid synchronization. num_tokens = cu_num_tokens[-1].item() - token_indices = torch.empty( + token_indices = torch.zeros( num_tokens, dtype=torch.int32, device=cu_num_tokens.device, @@ -136,9 +138,6 @@ def propose( # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] if token_indices is not None and self.runner.torchair_graph_enabled: last_token_indices = token_indices - else: - seq_lens = target_positions[last_token_indices] + 1 - seq_lens = seq_lens.cpu() self.input_ids[last_token_indices] = next_token_ids @@ -155,23 +154,36 @@ def propose( # input_batch=self.runner.input_batch, # scheduler_output=self.runner.scheduler_output, # ) - extra_builder_kwargs = {} - is_running_torchair = self.runner.torchair_graph_enabled and \ not self.runner.with_prefill if is_running_torchair: - extra_builder_kwargs['graph_pad_size'] = self.runner.graph_pad_size num_input_tokens = self.runner.graph_pad_size else: num_input_tokens = num_tokens - attn_metadata = self.runner.attn_metadata_builder.build( + seq_lens = target_positions[last_token_indices] + 1 + seq_lens = seq_lens.int() + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=cu_num_tokens[:batch_size + 1], + query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), + seq_lens_cpu=seq_lens.cpu(), num_reqs=batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, - query_start_loc=cu_num_tokens, - **extra_builder_kwargs) + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. 
+            get_device_tensor(),
+            slot_mapping_cpu=target_slot_mapping,
+            positions=target_positions,
+            attn_mask=self.runner.attn_mask,
+            spec_attn_mask=self.runner.spec_attn_mask,
+            attn_state=self.runner.attn_state,
+            graph_pad_size=self.runner.graph_pad_size,
+            decode_token_per_req=self.runner.decode_token_per_req,
+        )
+        attn_metadata = self.runner.attn_metadata_builder.build(
+            common_attn_metadata, self.runner.get_model())
 
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
@@ -281,8 +293,16 @@ def dummy_run(self,
         if skip_attn:
             attn_metadata = None
         else:
+            common_attn_metadata = TorchairCommonAttentionMetadata(
+                num_reqs=num_reqs,
+                num_actual_tokens=1,
+                actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
+                attn_mask=self.runner.attn_mask,
+                spec_attn_mask=self.runner.spec_attn_mask,
+                decode_token_per_req=self.runner.decode_token_per_req,
+            )
             attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_reqs, num_actual_tokens=1)
+                common_attn_metadata)
 
         input_ids = self.input_ids[:num_tokens]
         positions = self.positions[:num_tokens]
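The recurring change across this patch is the builder calling convention: metadata builders are now constructed from `(vllm_config, device)` instead of a weakref proxy to the model runner, and `build()` consumes a single `AscendCommonAttentionMetadata` plus the model. Below is a minimal sketch of that convention, modeled on the updated unit tests above. The MagicMock config, CPU device, and batch shapes are illustrative assumptions, and running it end to end assumes an environment where `vllm_ascend` (and `torch_npu`) import cleanly.

```python
import torch
from unittest.mock import MagicMock

from vllm_ascend.attention.attention_v1 import (
    AscendAttentionMetadataBuilder, AscendAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata

# Builder is built from the vLLM config and target device, not the runner.
vllm_config = MagicMock()
vllm_config.model_config.max_model_len = 640
vllm_config.cache_config.block_size = 64
builder = AscendAttentionMetadataBuilder(vllm_config, 'cpu')

# Every per-batch input that build() used to read off the runner now travels
# in one dataclass (values here are made up for illustration).
common = AscendCommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 3, 7]),
    query_start_loc_cpu=torch.tensor([0, 3, 7]),
    seq_lens_cpu=torch.tensor([5, 6]),
    num_reqs=2,
    num_actual_tokens=10,
    max_query_len=5,
    decode_token_per_req=1,
    block_table_tensor=torch.zeros((10, 10)),
    slot_mapping_cpu=torch.arange(20),
    actual_seq_lengths_q=[3, 7],
    positions=torch.arange(10),
    attn_mask=torch.ones((10, 10)),
    spec_attn_mask=None,
    attn_state=AscendAttentionState.ChunkedPrefill)

# The model argument is accepted but unused by this builder, so a mock
# stands in for it here.
attn_metadata = builder.build(common, MagicMock())
```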
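`split_decodes_and_prefills` replaces the `_num_decodes` / `_num_prefills` counters that `reorder_batch` used to stash on the builder. A small usage sketch on a hand-built metadata object (the request lengths are made up for illustration): two single-token decodes followed by one five-token prefill split exactly as the removed counters would have recorded.

```python
import torch

from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)

# Reordered batch: two decode requests (1 token each), then one 5-token prefill.
meta = AscendCommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 1, 2, 7]),
    query_start_loc_cpu=torch.tensor([0, 1, 2, 7]),
    seq_lens_cpu=torch.tensor([8, 9, 5]),
    num_reqs=3,
    num_actual_tokens=7,
    max_query_len=5,
    decode_token_per_req=1,
    block_table_tensor=torch.zeros((3, 4), dtype=torch.int32),
    slot_mapping_cpu=torch.arange(7),
    actual_seq_lengths_q=[1, 2, 7],
)

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
    split_decodes_and_prefills(meta, decode_threshold=1)
assert (num_decodes, num_prefills) == (2, 1)
assert (num_decode_tokens, num_prefill_tokens) == (2, 5)
```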
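`_get_graph_runner_block_tables` no longer slices a runner-owned `graph_block_tables` buffer; it zero-fills a `(num_seqs, max_blocks)` tensor, with `max_blocks` derived from `max_model_len` and `block_size`, and copies or truncates the live block table into it. A standalone restatement of that padding/truncation logic (the function name here is ours, not part of the patch):

```python
import torch


def pad_or_truncate_block_tables(block_tables: torch.Tensor,
                                 max_blocks: int) -> torch.Tensor:
    """Zero-pad (or truncate) per-request block tables to max_blocks columns,
    mirroring the new _get_graph_runner_block_tables behaviour."""
    num_seqs, num_blocks = block_tables.shape
    out = torch.zeros((num_seqs, max_blocks),
                      dtype=block_tables.dtype,
                      device=block_tables.device)
    if num_blocks <= max_blocks:
        out[:, :num_blocks] = block_tables
    else:
        out[:, :max_blocks] = block_tables[:, :max_blocks]
    return out


# max_blocks = ceil(max_model_len / block_size), e.g. ceil(1024 / 16) = 64.
tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
padded = pad_or_truncate_block_tables(tables, max_blocks=64)
assert padded.shape == (3, 64)
```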
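The torchair dummy path follows the same pattern: `build_torchair_graph_dummy` now takes a `TorchairCommonAttentionMetadata` instead of `(num_reqs, num_actual_tokens)`. A sketch of the call, using the same illustrative mock config as the first sketch and assuming the torchair builder constructs on the target host:

```python
import torch
from unittest.mock import MagicMock

from vllm_ascend.torchair.torchair_attention import \
    AscendAttentionTorchairMetadataBuilder
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata

# Illustrative mock config, as in the first sketch.
vllm_config = MagicMock()
vllm_config.model_config.max_model_len = 640
vllm_config.cache_config.block_size = 64
builder = AscendAttentionTorchairMetadataBuilder(vllm_config, 'cpu')

# Mirrors the call sites in torchair_model_runner.py and mtp_proposer_v1.py:
# only the dummy-graph shape information travels in the dataclass.
dummy = TorchairCommonAttentionMetadata(
    num_reqs=3,
    num_actual_tokens=1,
    decode_token_per_req=1,
    actual_seq_lengths_q=[1, 2, 3],
    attn_mask=torch.zeros((1, 1), dtype=torch.bool),
    spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
attn_metadata = builder.build_torchair_graph_dummy(dummy)
```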