
Commit 70f0763

[MM][Bugfix] Add error log for VL models when enabling FLASHCOMM (#4222)
### What this PR does / why we need it?

Add error log for VL models when enabling `VLLM_ASCEND_ENABLE_FLASHCOMM1=1` or `VLLM_ASCEND_ENABLE_FLASHCOMM=1` (for backward compatibility). This is a temporary fix for #4132.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: shen-shanshan <[email protected]>
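For context, the effect of this change is a fail-fast check at engine start-up: if the model is a VL model and FLASHCOMM is enabled, a `ValueError` is raised. Below is a minimal standalone sketch of that decision (an illustration only, not vllm-ascend code; the function name is made up for this example):

```python
def flashcomm_allowed(is_vl_model: bool, flashcomm_enabled: bool) -> bool:
    """Mirror of the new guard in platform.py: FLASHCOMM is rejected for VL models."""
    return not (is_vl_model and flashcomm_enabled)


# With VLLM_ASCEND_ENABLE_FLASHCOMM1=1 (or the legacy VLLM_ASCEND_ENABLE_FLASHCOMM=1),
# serving a VL model now fails fast at start-up; setting the flag to 0 restores start-up.
print(flashcomm_allowed(is_vl_model=True, flashcomm_enabled=True))   # False -> ValueError is raised
print(flashcomm_allowed(is_vl_model=True, flashcomm_enabled=False))  # True  -> engine starts normally
```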
1 parent: c94b38c

File tree: 2 files changed, +19 -1 lines changed

vllm_ascend/platform.py

Lines changed: 9 additions & 1 deletion
@@ -32,7 +32,7 @@
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                         delete_torchair_cache_file)
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
-                               update_aclgraph_sizes,
+                               is_vl_model, update_aclgraph_sizes,
                                update_default_aclgraph_sizes)
 
 if TYPE_CHECKING:
@@ -303,6 +303,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 vllm_config.scheduler_config)
             vllm_config.scheduler_config = recompute_scheduler_config
 
+        if is_vl_model(vllm_config):
+            if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \
+                    bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))):
+                raise ValueError(
+                    "Currently, VL models doesn't support "
+                    "FLASHCOMM in vllm-ascend. We will fix this in the future. "
+                    "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.")
+
     @classmethod
     def get_attn_backend_cls(
         cls,
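A note on the flag handling above: both the legacy `VLLM_ASCEND_ENABLE_FLASHCOMM` and the current `VLLM_ASCEND_ENABLE_FLASHCOMM1` are consulted, and the string values are converted with `bool(int(...))`. A small standalone sketch of that parsing (the helper name here is hypothetical):

```python
import os


def flashcomm_flag_set() -> bool:
    # Either spelling enables the feature; unset or "0" means disabled.
    return bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or \
        bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0")))


os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"  # legacy name still takes effect
print(flashcomm_flag_set())  # True -> the guard above raises for a VL model
```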

vllm_ascend/utils.py

Lines changed: 10 additions & 0 deletions
@@ -57,6 +57,7 @@
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
 _IS_MOE_MODEL = None
+_IS_VL_MODEL = None
 _ENABLE_SP = None
 _HAS_LAYER_IDX = None
 _ENABLE_NZ = None
@@ -696,6 +697,15 @@ def _is_contain_expert(config: Any):
     return False
 
 
+def is_vl_model(vllm_config: VllmConfig):
+    """Checks if the model is a VL model by config"""
+    global _IS_VL_MODEL
+    if _IS_VL_MODEL is None:
+        model_configs = vllm_config.model_config.hf_config.to_dict()
+        _IS_VL_MODEL = "VL" in model_configs["architectures"][0]
+    return _IS_VL_MODEL
+
+
 def weak_ref_tensor(tensor: Any) -> Any:
     """
     Create a weak reference to a tensor.
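The new `is_vl_model` helper caches its answer in the module-level `_IS_VL_MODEL` global, so the HF config is inspected only once per process, and the check itself is a substring test on the first entry of the `architectures` list. A standalone sketch of that test (the architecture names below are illustrative examples, not values read from a real config):

```python
def looks_like_vl(architectures: list[str]) -> bool:
    """Same substring check as is_vl_model: 'VL' in the first architecture name."""
    return "VL" in architectures[0]


print(looks_like_vl(["Qwen2VLForConditionalGeneration"]))  # True  -> treated as a VL model
print(looks_like_vl(["Qwen2ForCausalLM"]))                 # False -> not a VL model
```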
