diff --git a/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
index 68736c0928..4e3fcad629 100644
--- a/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
+++ b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
@@ -78,7 +78,12 @@ def test_mtp_correctness(
         },
         max_model_len=256,
         gpu_memory_utilization=0.8,
-        enforce_eager=True)
+        enforce_eager=True,
+        additional_config={
+            "ascend_scheduler_config": {
+                "enabled": True
+            },
+        })
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
     misses = 0
diff --git a/tests/singlecard/core/test_ascend_scheduler_e2e.py b/tests/singlecard/core/test_ascend_scheduler_e2e.py
index 668dafced9..70272b3da5 100644
--- a/tests/singlecard/core/test_ascend_scheduler_e2e.py
+++ b/tests/singlecard/core/test_ascend_scheduler_e2e.py
@@ -18,6 +18,7 @@ def model() -> LLM:
         MODEL,
         enforce_eager=True,
         enable_prefix_caching=True,
+        max_model_len=200,
         max_num_batched_tokens=200,
         max_num_seqs=3,
         additional_config={"ascend_scheduler_config": {
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 9d99a04cd8..499030b4b7 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -203,6 +203,13 @@ def check_ascend_config(vllm_config, enforce_eager):
                 "Ascend scheduler is only supported for V1 Engine.")
     # for v1 engine
     else:
+        # TODO(yexiong): remove this verification after mtp model supports original vllm scheduler
+        if (not ascend_config.ascend_scheduler_config.enabled
+                and vllm_config.speculative_config
+                and vllm_config.speculative_config.method == 'deepseek_mtp'):
+            raise NotImplementedError(
+                "Currently deepseek MTP model is only supported for ascend scheduler."
+            )
         # for eager mode
         if enforce_eager:
             # torchair_graph cannot be enabled with eager mode.
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 4a4131ecd7..4ee02e7ed4 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -55,6 +55,16 @@ def __post_init__(self) -> None:
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
         self.chunked_prefill_enabled = self.enable_chunked_prefill
+        if (self.max_num_batched_tokens < self.max_model_len
+                and not self.chunked_prefill_enabled):
+            raise ValueError(
+                "Ascend scheduler is enabled without chunked prefill feature. "
+                f"Argument max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
         if self.policy != "fcfs":
             raise NotImplementedError(
                 f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
             )
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 8eab3a14aa..9792f54446 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -27,6 +27,7 @@
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
 from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD,
                                check_torchair_cache_exist,
@@ -147,6 +148,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         check_ascend_config(vllm_config, enforce_eager)
 
+        if vllm_config.speculative_config and envs_ascend.VLLM_ASCEND_ENABLE_DBO:
+            raise ValueError(
+                "DBO and mtp can't work at the same time. Please `export VLLM_ASCEND_ENABLE_DBO=0`"
+            )
+
         if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
             logger.info("Compilation disabled, using eager mode by default")
             compilation_config.level = CompilationLevel.NO_COMPILATION
@@ -154,7 +160,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             logger.warning(
                 "NPU does not support %s compilation level. Setting level to NO_COMPILATION",
                 compilation_config.level)
-            compilation_config.level = CompilationLevel.NO_COMPILATION
+            compilation_config.level = CompilationLevel.NO_COMPILATION
         elif ascend_config.torchair_graph_config.enabled:
             logger.info(
                 "Torchair compilation enabled on NPU. Setting level to NO_COMPILATION"
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 347f9126a5..c622e47a9a 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -219,9 +219,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.spec_token_num = 0
         self.decode_token_per_req = 1
         if self.speculative_config:
-            if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
-                raise NotImplementedError(
-                    "DBO and mtp can't work at the same currently")
             self.use_spec_decode = True
             self.spec_token_num = self.speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
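
For reference, a minimal sketch (not part of the patch) of the configuration the new checks expect when running DeepSeek MTP speculative decoding: the Ascend scheduler must be enabled through additional_config, and with chunked prefill disabled, max_num_batched_tokens now has to cover max_model_len. The model path is a placeholder, and this assumes a vLLM build where LLM() forwards these engine arguments.

from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/DeepSeek-V3",  # placeholder model path
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    max_model_len=256,
    # Without chunked prefill, the new __post_init__ check requires
    # max_num_batched_tokens >= max_model_len, or engine init raises ValueError.
    max_num_batched_tokens=256,
    enforce_eager=True,
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True
        },
    })
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))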
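
Likewise, a hedged sketch of how the relocated DBO guard now surfaces: it fires in check_and_update_config() as a ValueError during engine construction, rather than as a NotImplementedError inside the model runner. This assumes VLLM_ASCEND_ENABLE_DBO is read lazily from the environment, so setting it before engine construction takes effect; the model path is again a placeholder.

import os

os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

from vllm import LLM

try:
    # Any speculative config combined with DBO should now be rejected
    # at platform config time.
    LLM(model="path/to/DeepSeek-V3",  # placeholder model path
        speculative_config={
            "method": "deepseek_mtp",
            "num_speculative_tokens": 1,
        })
except ValueError as err:
    # Expected: "DBO and mtp can't work at the same time. ..."
    print(err)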