Commit 1a7cbdb

A5 support reshape and cache in CP situation (#7636)
### What this PR does / why we need it?
Adds context-parallel (CP) support in the A5 scenario. On A5, the reshape-and-cache operation must go through an aclnn operator, so a routing entry is added to the DeviceAdaptor. In addition, the A5 aclnn operator requires contiguous inputs, while some operations produce non-contiguous tensors. For example, slicing with a stride, `slot_mapping = attn_metadata.slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]`, yields a non-contiguous `slot_mapping`. This PR therefore makes `key`, `value`, and `slot_mapping` contiguous before the call.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: vllm-project/vllm@ed359c4

---------

Signed-off-by: lenghuixing0330 <2531948770@qq.com>
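The contiguity issue described above can be reproduced in a few lines of PyTorch (sizes are hypothetical; only the strided-slicing pattern is taken from the patch):

```python
import torch

# Hypothetical sizes for illustration; pcp_size is the CP world size.
pcp_size = 2
num_decode_tokens = 4

slot_mapping_full = torch.arange(num_decode_tokens * pcp_size)

# Strided slicing (step = pcp_size) returns a view with a non-unit stride,
# so the result is not contiguous in memory.
slot_mapping = slot_mapping_full[: num_decode_tokens * pcp_size : pcp_size]
print(slot_mapping.is_contiguous())   # False

# .contiguous() copies the view into a fresh, densely packed buffer,
# which is what the aclnn operator requires.
slot_mapping = slot_mapping.contiguous()
print(slot_mapping.is_contiguous())   # True
```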
1 parent dbf1348 commit 1a7cbdb

File tree

2 files changed: +10 −5 lines changed


vllm_ascend/attention/context_parallel/attention_cp.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -49,6 +49,7 @@
     split_decodes_and_prefills,
 )
 from vllm_ascend.compilation.acl_graph import get_graph_params, update_graph_params_workspaces
+from vllm_ascend.device.device_op import DeviceOperator
 from vllm_ascend.utils import cp_chunkedprefill_comm_stream, weak_ref_tensors
@@ -752,12 +753,12 @@ def reshape_and_cache(
 
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
-            torch_npu._npu_reshape_and_cache(
+            DeviceOperator.reshape_and_cache(
                 key=key[:num_decode_tokens],
                 value=value[:num_decode_tokens],
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
-                slot_indices=slot_mapping,
+                slot_mapping=slot_mapping,
             )
 
         if has_prefill:
@@ -784,12 +785,12 @@ def reshape_and_cache(
             slot_mapping = attn_metadata.slot_mapping[
                 self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded
             ]
-            torch_npu._npu_reshape_and_cache(
+            DeviceOperator.reshape_and_cache(
                 key=prefill_key,
                 value=prefill_value,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
-                slot_indices=slot_mapping,
+                slot_mapping=slot_mapping,
             )
             if self.is_kv_producer:
                 attn_metadata.reshape_cache_event.record()
```
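As a plain-Python sketch (token counts are hypothetical; only the slicing expressions mirror the hunks above), the decode path takes every `pcp_size`-th entry of the padded slot mapping, and the prefill path takes the tail:

```python
# Hypothetical sizes chosen for illustration.
pcp_size = 2
num_decode_tokens = 3
num_prefill_tokens = 4

# Padded slot mapping: pcp_size entries per decode token, then prefill entries.
slot_mapping_padded = list(range(num_decode_tokens * pcp_size + num_prefill_tokens))

# Decode side: every pcp_size-th entry (a strided, non-contiguous view in torch).
decode_slots = slot_mapping_padded[: num_decode_tokens * pcp_size : pcp_size]
print(decode_slots)   # [0, 2, 4]

# Prefill side: the remainder after the decode block.
prefill_slots = slot_mapping_padded[pcp_size * num_decode_tokens :]
print(prefill_slots)  # [6, 7, 8, 9]
```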

vllm_ascend/device/device_op.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -204,7 +204,11 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
     @classmethod
     def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
         torch_npu.npu_scatter_pa_kv_cache(
-            key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping
+            key=key.contiguous(),
+            value=value.contiguous(),
+            key_cache=key_cache,
+            value_cache=value_cache,
+            slot_mapping=slot_mapping.contiguous(),
         )
 
     @staticmethod
```
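The `DeviceOperator` indirection this PR routes through follows a device-adaptor dispatch pattern. A minimal sketch (class names borrowed from the diff, selection logic hypothetical, with a stand-in body instead of `torch_npu.npu_scatter_pa_kv_cache`):

```python
class BaseDeviceAdaptor:
    """Default device routing; subclasses override ops with device-specific kernels."""

    @classmethod
    def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
        raise NotImplementedError("no default reshape-and-cache kernel in this sketch")


class A5DeviceAdaptor(BaseDeviceAdaptor):
    @classmethod
    def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
        # The A5 aclnn kernel requires contiguous inputs; the real adaptor calls
        # torch_npu.npu_scatter_pa_kv_cache after making key/value/slot_mapping
        # contiguous. Here we just return a marker for the chosen path.
        return "a5-aclnn-path"


def select_adaptor(soc_version: str) -> type:
    # Hypothetical selection logic: route A5 SoCs to the aclnn-based adaptor.
    return A5DeviceAdaptor if soc_version.startswith("A5") else BaseDeviceAdaptor
```

Callers then invoke `select_adaptor(...).reshape_and_cache(...)` without caring which kernel backs it, which is why `attention_cp.py` above can swap `torch_npu._npu_reshape_and_cache` for `DeviceOperator.reshape_and_cache` in one place.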

0 commit comments