Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions vllm_ascend/attention/context_parallel/attention_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
split_decodes_and_prefills,
)
from vllm_ascend.compilation.acl_graph import get_graph_params, update_graph_params_workspaces
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.utils import cp_chunkedprefill_comm_stream, weak_ref_tensors


Expand Down Expand Up @@ -759,12 +760,12 @@ def reshape_and_cache(

if has_decode:
slot_mapping = attn_metadata.slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
torch_npu._npu_reshape_and_cache(
DeviceOperator.reshape_and_cache(
key=key[:num_decode_tokens],
value=value[:num_decode_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slot_mapping,
slot_mapping=slot_mapping,
)

if has_prefill:
Expand All @@ -791,12 +792,12 @@ def reshape_and_cache(
slot_mapping = attn_metadata.slot_mapping[
self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded
]
torch_npu._npu_reshape_and_cache(
DeviceOperator.reshape_and_cache(
key=prefill_key,
value=prefill_value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slot_mapping,
slot_mapping=slot_mapping,
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
Expand Down
6 changes: 5 additions & 1 deletion vllm_ascend/device/device_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
@classmethod
def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
torch_npu.npu_scatter_pa_kv_cache(
key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping
key=key.contiguous(),
value=value.contiguous(),
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping.contiguous(),
)

@staticmethod
Expand Down
Loading