1 change: 1 addition & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -237,6 +237,7 @@ def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(
84 changes: 67 additions & 17 deletions vllm_ascend/attention/mla_v1.py
@@ -169,7 +169,7 @@ def split_metadata_for_multistream(
 class AscendMLAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_BATCH
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -389,6 +389,8 @@ def build(

         decode_metadata = None
         if num_decodes > 0:
+            cos = common_attn_metadata.cos
+            sin = common_attn_metadata.sin
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
@@ -397,21 +399,45 @@ def build(
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
 
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=block_table,
-                seq_lens=seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos)
+            # TODO: After full graph supports MTP, this if branch needs to be deleted
+            assert self.cos_cache is not None
+            assert self.sin_cache is not None
+            if cos is None and sin is None:
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin,
+                    cos=cos)
+            else:
+                cos[:num_decode_tokens,
+                    ...] = self.cos_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+                sin[:num_decode_tokens,
+                    ...] = self.sin_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin[:num_decode_tokens, ...],
+                    cos=cos[:num_decode_tokens, ...])

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -431,6 +457,29 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
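Why the in-place `cos[:num_decode_tokens, ...] = ...` writes above: a replayed full graph reads from fixed tensor addresses, so rope values must be written into the persistent buffers rather than freshly allocated. A minimal sketch of that contract, with assumed shapes and names rather than the real vLLM-Ascend API:

import torch

MAX_NUM_REQS, ROPE_DIM = 8, 64
cos_buf = torch.ones(MAX_NUM_REQS, 1, 1, ROPE_DIM)  # persistent, captured once
cos_cache = torch.randn(4096, ROPE_DIM)             # rope table indexed by position

def fill_for_step(input_positions: torch.Tensor) -> torch.Tensor:
    n = input_positions.numel()
    # In-place slice assignment keeps cos_buf's storage, so a replayed graph
    # that captured cos_buf sees the fresh values.
    cos_buf[:n, ...] = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
    return cos_buf[:n, ...]

print(fill_for_step(torch.tensor([3, 7, 11])).shape)  # torch.Size([3, 1, 1, 64])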

+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state in {
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        }:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding states"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata


 class DecodeMLAPreprocessResult(NamedTuple):
     ql_nope: Optional[torch.Tensor] = None
@@ -814,7 +863,8 @@ def _forward_decode(

         if attn_metadata.attn_state in [
                 AscendAttentionState.SpecDecoding,
-                AscendAttentionState.ChunkedPrefill
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.DecodeOnly,
         ] and self.speculative_config is not None:
             # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
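For orientation, a shape-only sketch of the layout distinction, with semantics assumed from common usage rather than taken from this PR: "BSH" keeps a per-request batch dimension, while "TND" packs every query token across requests into one leading dimension, which is what uniform spec-decode batches need:

import torch

B, S, N, D = 4, 2, 8, 64        # 4 requests, 2 query tokens each (MTP-style)
bsh = torch.randn(B, S, N * D)  # BSH: [batch, seq, num_heads * head_dim]
tnd = bsh.view(B * S, N, D)     # TND: [total_tokens, num_heads, head_dim]
print(tnd.shape)                # torch.Size([8, 8, 64])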
4 changes: 4 additions & 0 deletions vllm_ascend/attention/utils.py
@@ -63,6 +63,10 @@ class AscendCommonAttentionMetadata:

     graph_pad_size: int = -1
 
+    # NOTE: This is a temporary solution for rotary embedding in MLA
+    cos: torch.Tensor = None
+    sin: torch.Tensor = None
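A stricter annotation for these fields would be `Optional[torch.Tensor]`, since they default to `None` and are filled only on the MLA full-graph decode path. A minimal sketch of the intended contract using a hypothetical stand-in dataclass, not the real `AscendCommonAttentionMetadata`:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class CommonMeta:  # hypothetical stand-in
    graph_pad_size: int = -1
    cos: Optional[torch.Tensor] = None  # persistent buffer, MLA full graph only
    sin: Optional[torch.Tensor] = None

meta = CommonMeta(cos=torch.ones(4, 1, 1, 64), sin=torch.zeros(4, 1, 1, 64))
print(meta.cos.shape)  # torch.Size([4, 1, 1, 64])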


 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,
54 changes: 54 additions & 0 deletions vllm_ascend/compilation/acl_graph.py
@@ -229,6 +229,60 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         event.record(update_stream)


+def update_mla_attn_params(update_stream, forward_context, runtime_shape,
+                           speculative_config):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+         spec_attn_mask, sparse_mode, scale, block_table, block_size,
+         seq_lens_list, actual_seq_lengths, workspace, attn_output,
+         softmax_lse) = param
+        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
+        if speculative_config and speculative_config.method == "deepseek_mtp":
+            actual_seq_lengths = forward_context.attn_metadata[
+                key].decode.actual_seq_lengths_q
+            spec_multiple = speculative_config.num_speculative_tokens + 1
+            seq_lens_list = seq_lens_list + [0] * (
+                runtime_shape // spec_multiple - len(seq_lens_list))
+            actual_seq_lengths = [
+                spec_multiple * (i + 1)
+                for i in range(runtime_shape // spec_multiple)
+            ]
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                k_nope,
+                query_rope=q_pe,
+                key_rope=k_pe,
+                num_heads=num_heads,
+                num_key_value_heads=num_kv_heads,
+                input_layout=input_layout,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
+                scale=scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens_list,
+                actual_seq_lengths=actual_seq_lengths,
+                workspace=workspace,
+                out=[attn_output, softmax_lse])
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)

Collaborator review comment on the `workspace=workspace` argument: Should retrieve workspace from graph params.
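A worked example of the padding arithmetic in the `deepseek_mtp` branch above, with assumed values: one speculative token gives `spec_multiple = 2`, so a captured `runtime_shape` of 8 covers 4 request slots; dead slots get zero-length KV and the query end-offsets stay uniform:

num_speculative_tokens, runtime_shape = 1, 8
spec_multiple = num_speculative_tokens + 1
seq_lens_list = [17, 33, 9]  # 3 live requests this step
seq_lens_list = seq_lens_list + [0] * (
    runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
    spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)
]
print(seq_lens_list)       # [17, 33, 9, 0]
print(actual_seq_lengths)  # [2, 4, 6, 8]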


 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]
56 changes: 50 additions & 6 deletions vllm_ascend/worker/model_runner_v1.py
@@ -102,7 +102,8 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
-                                               update_attn_params)
+                                               update_attn_params,
+                                               update_mla_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -351,6 +352,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+        self.actual_seq_lengths_q = list(
+            range(self.decode_token_per_req, self.max_num_tokens + 1,
+                  self.decode_token_per_req))
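A worked example with assumed values, also showing how `_build_attention_metadata` later turns this ladder into a dummy `query_start_loc` (see that hunk below): assuming `decode_token_per_req = 2` and `max_num_tokens = 8`,

decode_token_per_req, max_num_tokens, num_reqs = 2, 8, 3
actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
print(actual_seq_lengths_q)                   # [2, 4, 6, 8]
print([0] + actual_seq_lengths_q[:num_reqs])  # [0, 2, 4, 6] -> dummy query_start_loc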

         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
@@ -369,6 +373,25 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                      dtype=torch.int32,
                                      device=self.device)

+        if self.vllm_config.model_config.use_mla and \
+            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+            self.cos = torch.ones(self.max_num_reqs,
+                                  1,
+                                  1,
+                                  rope_dim,
+                                  dtype=self.dtype,
+                                  device=self.device)
+            self.sin = torch.zeros(self.max_num_reqs,
+                                   1,
+                                   1,
+                                   rope_dim,
+                                   dtype=self.dtype,
+                                   device=self.device)
+        else:
+            self.cos = None
+            self.sin = None
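A note on the initial values, as a sketch of the presumed rationale rather than anything stated in the PR: `cos = 1`, `sin = 0` corresponds to a rotation angle of zero, so buffer rows that have not been overwritten yet apply the identity rotation:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2)
    return torch.cat((-x2, x1))

q = torch.randn(64)
cos, sin = torch.ones(64), torch.zeros(64)
q_rot = q * cos + rotate_half(q) * sin
print(torch.allclose(q_rot, q))  # True: zero-angle rope is the identity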

         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
         if self.uses_mrope:
@@ -1518,6 +1541,8 @@ def _prepare_inputs(
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )

         if self.speculative_config and \
@@ -1547,7 +1572,7 @@ def _prepare_inputs(
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
-                model=self.model,
+                model=self.get_model(),
                 **extra_attn_metadata_args)
 
         if self.vllm_config.model_config.use_mla:
@@ -1582,8 +1607,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,

         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            update_attn_params(self.update_stream, forward_context,
-                               positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0],
+                                       self.speculative_config)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   positions.shape[0])

         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2285,7 +2316,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc=torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.device,
+                    dtype=torch.int32),
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                              1],
                 seq_lens_cpu=self.seq_lens_cpu,
@@ -2296,17 +2330,27 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
                 block_table_tensor=block_table_tensor[:num_reqs],
                 slot_mapping=self.slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
                 positions=self.positions,
                 attn_mask=self.attn_mask,
                 spec_attn_mask=self.spec_attn_mask,
                 attn_state=self.attn_state,
                 max_query_len=max_query_len,
                 decode_token_per_req=self.decode_token_per_req,
+                cos=self.cos,
+                sin=self.sin,
             )
+            attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and \
+                    self.speculative_config.method == "deepseek_mtp":
+                attn_state = AscendAttentionState.SpecDecoding

             for attn_group in self.attn_groups[kv_cache_group_id]:
                 if vllm_version_is("0.10.2"):
                     builder = attn_group.metadata_builder
                 else:
                     builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata)
+                    common_attn_metadata, attn_state, self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
