59 changes: 59 additions & 0 deletions examples/offline_inference_npu_long_seq.py
@@ -0,0 +1,59 @@
import os
import time
import argparse

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument('--input_len', type=int, default=1024)
parser.add_argument('--output_len', type=int, default=128)
parser.add_argument('--bs', type=int, default=1)
parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite")
parser.add_argument('--tp', type=int, default=2)
parser.add_argument('--cp', type=int, default=2)
parser.add_argument('--dcp', type=int, default=2)
parser.add_argument('--iter_times', type=int, default=1)

args = parser.parse_args()

prompts = [
"The capital of France is",
"Hello, my name is Tom, I am",
"The president of United States is",
"AI future is? What do you think about it? Can you give me some information or any thing you want?"
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=args.output_len)
llm = LLM(
model=args.model_path,
trust_remote_code=True,
enforce_eager=True,
tensor_parallel_size=args.tp,
context_parallel_size=args.cp,
decode_context_parallel_size=args.dcp,
enable_prefix_caching=False,
enable_expert_parallel=True,
enable_chunked_prefill=False,
max_num_batched_tokens=args.input_len + 138,
max_model_len=args.input_len + args.output_len + 138,
additional_config={"ascend_scheduler_config": {"enabled": True}},
max_num_seqs=1,
block_size=128,
gpu_memory_utilization=0.9
)

t0 = time.time()
for _ in range(args.iter_times):
outputs = llm.generate(prompts, sampling_params)
t1 = time.time()
print(f"TTFT: {(t1 - t0) * 1000 / (args.iter_times * args.bs)} ms")

for i, output in enumerate(outputs):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"req_num: {i}\nGenerated text: {generated_text!r}")
444 changes: 425 additions & 19 deletions vllm_ascend/attention/mla_v1.py

Large diffs are not rendered by default.

37 changes: 35 additions & 2 deletions vllm_ascend/attention/utils.py
@@ -1,13 +1,44 @@
from dataclasses import dataclass
from typing import Any, List

from typing import Any, List, Optional
import torch
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context


@dataclass
class AscendCommonLongSequenceMetadata:
cp_kv_recover_idx: Optional[torch.Tensor] = None

num_actual_tokens_cp_full: Optional[int] = None

num_computed_tokens_of_cp_sp: Optional[list[Optional[list[Optional[
list[int]]]]]] = None

q_head_idx_tensor: Optional[torch.Tensor] = None

q_tail_idx_tensor: Optional[torch.Tensor] = None

kv_with_q_head_nomask_idx_tensor: Optional[torch.Tensor] = None

kv_with_q_head_mask_idx_tensor: Optional[torch.Tensor] = None

kv_with_q_tail_nomask_idx_tensor: Optional[torch.Tensor] = None

kv_with_q_tail_mask_idx_tensor: Optional[torch.Tensor] = None

attn_mask_seqlens: Optional[torch.Tensor] = None

head_attn_nomask_seqlens: Optional[torch.Tensor] = None

tail_attn_nomask_seqlens: Optional[torch.Tensor] = None

q_full_idx: Optional[torch.Tensor] = None

cp_prefill_mask: Optional[torch.Tensor] = None


@dataclass
class AscendCommonAttentionMetadata:
"""
@@ -63,6 +94,8 @@ class AscendCommonAttentionMetadata:

graph_pad_size: int = -1

common_long_seq_metadata: Optional[AscendCommonLongSequenceMetadata] = None


def split_decodes_and_prefills(
common_attn_metadata: AscendCommonAttentionMetadata,
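As a rough illustration of how the new metadata container could be populated (all field values below are placeholders, not what mla_v1.py actually computes):

import torch

from vllm_ascend.attention.utils import AscendCommonLongSequenceMetadata

# Every field defaults to None, so callers can fill in only what the current
# phase needs; the values here are placeholders for illustration only.
long_seq_meta = AscendCommonLongSequenceMetadata(
    cp_kv_recover_idx=torch.arange(16),
    num_actual_tokens_cp_full=16,
    num_computed_tokens_of_cp_sp=[[[8, 8]]],  # one request, cp=1, dcp=2
)
# In mla_v1.py this would presumably be attached via
# AscendCommonAttentionMetadata.common_long_seq_metadata.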
1 change: 1 addition & 0 deletions vllm_ascend/core/schedule_config.py
@@ -19,6 +19,7 @@
from typing import Type, Union

from vllm.config import SchedulerConfig
from vllm.distributed import get_dcp_group


@dataclass
13 changes: 10 additions & 3 deletions vllm_ascend/distributed/parallel_state.py
@@ -7,6 +7,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import context_parallel_enable

# Currently, mc2 ops need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -51,9 +52,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
if context_parallel_enable():
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.context_parallel_size *
parallel_config.tensor_parallel_size)
else:
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
global _MC2
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
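To make the reshape above concrete, here is a small standalone sketch (world and parallel sizes are assumptions: 8 ranks, dp=1, cp=2, tp=2) showing how enabling CP changes the rank grouping that feeds the downstream group coordinators:

import torch

# Each row of the reshaped tensor becomes one rank group after unbind(0).
world_size, dp, cp, tp = 8, 1, 2, 2

with_cp = torch.arange(world_size).reshape(-1, dp * cp * tp)
without_cp = torch.arange(world_size).reshape(-1, dp * tp)

print([x.tolist() for x in with_cp.unbind(0)])     # [[0, 1, 2, 3], [4, 5, 6, 7]]
print([x.tolist() for x in without_cp.unbind(0)])  # [[0, 1], [2, 3], [4, 5], [6, 7]]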
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -159,6 +159,9 @@
# caused by the initialization of the Mooncake connector.
"PHYSICAL_DEVICES":
lambda: os.getenv("PHYSICAL_DEVICES", None),
# Whether to enable context parallelism (CP).
"VLLM_ASCEND_ENABLE_CP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CP", '0')))
}

# end-env-vars-definition
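A minimal sketch of turning the new flag on from user code, assuming the environment variable is set before the engine reads it:

import os

# "1" -> True, "0" (the default) -> False, per the bool(int(...)) conversion above.
os.environ["VLLM_ASCEND_ENABLE_CP"] = "1"

# CP-specific code paths are then gated on this flag, e.g. via
# vllm_ascend.utils.context_parallel_enable() (see utils.py below).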
32 changes: 0 additions & 32 deletions vllm_ascend/platform.py
@@ -128,38 +128,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
ascend_scheduler_config = ascend_config.ascend_scheduler_config
if vllm_version_is("0.10.2"):
structured_outputs_config = vllm_config.decoding_config
else:
structured_outputs_config = vllm_config.structured_outputs_config

if model_config is not None and not model_config.use_mla:
logger.info(
"Non-MLA LLMs forcibly disable the chunked prefill feature,"
"as the performance of operators supporting this feature "
"functionality is currently suboptimal.")
if not model_config.is_multimodal_model and \
structured_outputs_config.backend == "auto" and \
not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
not scheduler_config.send_delta_data and \
scheduler_config.policy == "fcfs":
ascend_scheduler_config.enabled = True
chunked_prefill_enabled_in_ascend_scheduler = getattr(
ascend_scheduler_config, "enable_chunked_prefill", False)
if chunked_prefill_enabled_in_ascend_scheduler:
logger.warning(
"Chunked prefill feature is enabled in ascend_scheduler,"
"but note that the operator supporting this feature "
"would lead to performance degradation.")
# In this situation, max_num_batched_tokens would have been rewritten.
# So we must make sure max_num_batched_tokens is not smaller than max_model_len.
if (scheduler_config.max_num_batched_tokens
< scheduler_config.max_model_len
and not chunked_prefill_enabled_in_ascend_scheduler):
scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len

kv_cache_dtype = vllm_config.additional_config.get(
"kv_cache_dtype", None)
if kv_cache_dtype is not None:
4 changes: 4 additions & 0 deletions vllm_ascend/utils.py
@@ -598,6 +598,10 @@ def enable_sp() -> bool:
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)


def context_parallel_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_CP


def is_moe_model(vllm_config: VllmConfig):
config = vllm_config.model_config.hf_config
return any('experts' in key.lower() for key in config.to_dict())
46 changes: 41 additions & 5 deletions vllm_ascend/worker/block_table.py
@@ -5,6 +5,11 @@
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv

from vllm_ascend.utils import context_parallel_enable

if context_parallel_enable():
from vllm.distributed import get_cp_group


class BlockTable:

@@ -80,12 +85,16 @@ def __init__(self,
dtype=torch.int64,
device=self.device)
try:
self.cp_world_size = get_cp_group().world_size if context_parallel_enable() else 1
self.cp_rank = get_cp_group().rank_in_group if self.cp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.cp_world_size = 1
self.cp_rank = 0
self.kernel_sizes = kernel_sizes

def append_row(
@@ -132,14 +141,14 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
# here because M (max_model_len) is not necessarily divisible by
# block_size.

if self.dcp_world_size > 1:
if self.dcp_world_size * self.cp_world_size > 1:
# Note(hc): DCP/CP store the kvcache in an interleaved style: the kvcache
# for the token whose token_idx is i is always stored on the device whose
# combined rank (dcp_world_size * cp_rank + dcp_rank) equals
# i % (dcp_world_size * cp_world_size):

# Use a "virtual block" equal to dcp_world_size * cp_world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
virtual_block_size = self.block_size * self.dcp_world_size * self.cp_world_size

# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
@@ -157,9 +166,11 @@
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank
mask = (virtual_block_offsets %
(self.dcp_world_size * self.cp_world_size) == self.current_rank)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
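A small self-contained sketch of the interleaved slot mapping above, with hypothetical sizes (block_size=4, dcp_world_size=2, cp_world_size=2) for the rank holding cp_rank=1, dcp_rank=0:

import numpy as np

block_size, dcp_ws, cp_ws = 4, 2, 2
virtual_block_size = block_size * dcp_ws * cp_ws            # 16
positions = np.arange(8)                                    # token positions 0..7
virtual_block_offsets = positions % virtual_block_size
current_rank = dcp_ws * 1 + 0                               # cp_rank=1, dcp_rank=0 -> rank 2
mask = virtual_block_offsets % (dcp_ws * cp_ws) == current_rank
block_offsets = virtual_block_offsets // (dcp_ws * cp_ws)
print(mask.tolist())           # [False, False, True, False, False, False, True, False]
print(block_offsets.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]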
@@ -249,9 +260,11 @@ def __init__(self,
# must be multiplied by dcp_world_size * cp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
cp_world_size = get_cp_group().world_size if context_parallel_enable() else 1
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
cp_world_size = 1

if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
@@ -267,7 +280,7 @@ def __init__(self,
self.block_tables = [
BlockTable(
block_size, max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size),
max(cdiv(max_model_len, block_size * dcp_world_size * cp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device, kernel_size_list)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
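As a quick worked example of the per-rank block budget above (all sizes are assumptions):

from vllm.utils import cdiv

# With max_model_len=1290, block_size=128, dcp_world_size=2 and cp_world_size=2,
# each rank only needs cdiv(1290, 128 * 2 * 2) = 3 block-table slots per request,
# versus cdiv(1290, 128) = 11 without any kv-cache sharding across ranks.
print(cdiv(1290, 128 * 2 * 2))  # 3
print(cdiv(1290, 128))          # 11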
@@ -303,6 +316,29 @@ def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)

def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]:
"""Split computed token counts across the cp and dcp ranks for distributed allocation."""
self.cp_world_size = get_cp_group().world_size if context_parallel_enable() else 1
self.dcp_world_size = get_dcp_group().world_size
num_requests = len(num_computed_tokens)
num_computed_tokens_of_dcp_sp = [[
[0] * self.dcp_world_size for _ in range(self.cp_world_size)
] for _ in range(num_requests)]
total_ranks = self.cp_world_size * self.dcp_world_size
for req_idx in range(num_requests):
total_tokens = num_computed_tokens[req_idx]
if total_tokens <= 0:
continue
base = int(total_tokens) // total_ranks
remainder = int(total_tokens) % total_ranks
for rank_idx in range(total_ranks):
cp_idx = rank_idx // self.dcp_world_size
sp_idx = rank_idx % self.dcp_world_size
num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] = base
if rank_idx < remainder:
num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] += 1
return num_computed_tokens_of_dcp_sp

def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
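To illustrate what get_split_computed_tokens above produces, here is a pure-Python replica of its splitting rule for a single request with 10 computed tokens, cp_world_size=2 and dcp_world_size=2 (sizes are assumptions; the real method reads the group sizes from the distributed state):

num_computed_tokens = [10]
cp_ws, dcp_ws = 2, 2
total_ranks = cp_ws * dcp_ws

split = [[[0] * dcp_ws for _ in range(cp_ws)] for _ in num_computed_tokens]
for req_idx, total in enumerate(num_computed_tokens):
    base, remainder = divmod(total, total_ranks)
    for rank_idx in range(total_ranks):
        cp_idx, dcp_idx = divmod(rank_idx, dcp_ws)
        split[req_idx][cp_idx][dcp_idx] = base + (1 if rank_idx < remainder else 0)

print(split)  # [[[3, 3], [2, 2]]] -> the first two ranks absorb the remainder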