Commit c3a2c6a

[MM][Core] Decouple ViT backend from LM backend (#27061)
Signed-off-by: Roger Wang <[email protected]>
1 parent 72f431e commit c3a2c6a

16 files changed: +230 −17 lines
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.attention.backends.registry import _Backend
+from vllm.config.multimodal import MultiModalConfig
+
+
+def test_mm_encoder_attn_backend_str_conversion():
+    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
+    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
+
+
+def test_mm_encoder_attn_backend_invalid():
+    with pytest.raises(ValueError):
+        MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
+
+
+def test_mm_encoder_attn_backend_hash_updates():
+    base_hash = MultiModalConfig().compute_hash()
+    overridden_hash = MultiModalConfig(
+        mm_encoder_attn_backend=_Backend.FLASH_ATTN
+    ).compute_hash()
+    assert base_hash != overridden_hash

vllm/attention/layer.py

Lines changed: 10 additions & 1 deletion
@@ -16,6 +16,7 @@
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
@@ -443,6 +444,7 @@ def __init__(
         # This has no effect, it is only here to make it easier to swap
         # between Attention and MultiHeadAttention
         prefix: str = "",
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -462,7 +464,14 @@ def __init__(
         dtype = torch.get_default_dtype()

         # Determine the attention backend
-        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+        attn_backend_override = None
+        if multimodal_config is not None:
+            attn_backend_override = multimodal_config.mm_encoder_attn_backend
+        backend = get_vit_attn_backend(
+            head_size=head_size,
+            dtype=dtype,
+            attn_backend_override=attn_backend_override,
+        )

         # Some auto-selected backends can be upgraded
         # to upstream flash attention if available.
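
The MultiHeadAttention change above is what lets the ViT attention backend diverge from the LM backend: the override travels in MultiModalConfig and is handed to get_vit_attn_backend. A minimal sketch of that selection flow, where pick_vit_backend_override is a hypothetical helper (the commit keeps this logic inline in MultiHeadAttention.__init__):

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig


def pick_vit_backend_override(
    multimodal_config: MultiModalConfig | None,
) -> _Backend | None:
    # Mirrors the added lines: no config (or no override) means None, which
    # leaves get_vit_attn_backend free to auto-select a ViT backend.
    if multimodal_config is None:
        return None
    return multimodal_config.mm_encoder_attn_backend


# An explicit FLASH_ATTN override for the encoder, independent of the LM backend.
cfg = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert pick_vit_backend_override(cfg) is _Backend.FLASH_ATTN
assert pick_vit_backend_override(None) is None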

vllm/config/model.py

Lines changed: 5 additions & 0 deletions
@@ -50,13 +50,15 @@

     import vllm.model_executor.layers.quantization as me_quant
     import vllm.model_executor.models as me_models
+    from vllm.attention.backends.registry import _Backend
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.v1.sample.logits_processor import LogitsProcessor
 else:
     PretrainedConfig = Any

+    _Backend = Any
     me_quant = LazyLoader(
         "model_executor", globals(), "vllm.model_executor.layers.quantization"
     )
@@ -307,6 +309,7 @@ class ModelConfig:
     mm_processor_cache_type: InitVar[MMCacheType | None] = None
     mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
     mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
+    mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
     interleave_mm_strings: InitVar[bool | None] = None
     skip_mm_profiling: InitVar[bool | None] = None
     video_pruning_rate: InitVar[float | None] = None
@@ -424,6 +427,7 @@ def __post_init__(
         mm_processor_cache_type: MMCacheType | None,
         mm_shm_cache_max_object_size_mb: int | None,
         mm_encoder_tp_mode: MMEncoderTPMode | None,
+        mm_encoder_attn_backend: _Backend | str | None,
         interleave_mm_strings: bool | None,
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
@@ -733,6 +737,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             mm_processor_cache_type=mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=mm_encoder_tp_mode,
+            mm_encoder_attn_backend=mm_encoder_attn_backend,
             interleave_mm_strings=interleave_mm_strings,
             skip_mm_profiling=skip_mm_profiling,
             video_pruning_rate=video_pruning_rate,

vllm/config/multimodal.py

Lines changed: 38 additions & 4 deletions
@@ -3,13 +3,18 @@

 import hashlib
 from collections.abc import Mapping
-from typing import Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass

 from vllm.config.utils import config

+if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
+else:
+    _Backend = Any
+

 @dataclass
 class BaseDummyOptions:
@@ -112,6 +117,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
+    mm_encoder_attn_backend: _Backend | None = None
+    """Optional override for the multi-modal encoder attention backend when
+    using vision transformers. Accepts any value from
+    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""
@@ -148,6 +157,29 @@ def _validate_limit_per_prompt(
                 value[k] = BaseDummyOptions(**v)
         return value

+    @field_validator("mm_encoder_attn_backend", mode="before")
+    @classmethod
+    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
+        from vllm.attention.backends.registry import (
+            _Backend as BackendEnum,
+        )
+        from vllm.attention.backends.registry import (
+            backend_name_to_enum,
+        )
+
+        if value is None or isinstance(value, BackendEnum):
+            return value
+
+        if isinstance(value, str):
+            candidate = backend_name_to_enum(value.upper())
+            if candidate is not None:
+                return candidate
+
+        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
+        raise ValueError(
+            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
+        )
+
     @model_validator(mode="after")
     def _validate_multimodal_config(self):
         if self.mm_processor_cache_type != "shm" and (
@@ -172,9 +204,11 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
+        factors: list[Any] = [
+            self.mm_encoder_attn_backend.name
+            if self.mm_encoder_attn_backend is not None
+            else None
+        ]
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
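
The validator and hash changes above can be exercised directly, as the new tests do. A short sketch, assuming FLASH_ATTN is available in the _Backend registry of your build:

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig

# Strings are coerced to the _Backend enum by the "before" validator;
# value.upper() makes the lookup case-insensitive.
cfg = MultiModalConfig(mm_encoder_attn_backend="flash_attn")
assert cfg.mm_encoder_attn_backend == _Backend.FLASH_ATTN

# Unknown names raise a ValueError listing the valid members.
try:
    MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
except ValueError as exc:
    print(exc)

# The override now participates in compute_hash(), so configs that differ only
# in the encoder backend produce different hashes.
assert MultiModalConfig().compute_hash() != cfg.compute_hash()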

vllm/engine/arg_utils.py

Lines changed: 9 additions & 0 deletions
@@ -32,6 +32,7 @@
 from typing_extensions import TypeIs, deprecated

 import vllm.envs as envs
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
@@ -451,6 +452,9 @@ class EngineArgs:
         MultiModalConfig.mm_shm_cache_max_object_size_mb
     )
     mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
+    mm_encoder_attn_backend: _Backend | str | None = (
+        MultiModalConfig.mm_encoder_attn_backend
+    )
     io_processor_plugin: str | None = None
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     video_pruning_rate: float = MultiModalConfig.video_pruning_rate
@@ -914,6 +918,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         multimodal_group.add_argument(
             "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
         )
+        multimodal_group.add_argument(
+            "--mm-encoder-attn-backend",
+            **multimodal_kwargs["mm_encoder_attn_backend"],
+        )
         multimodal_group.add_argument(
             "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
         )
@@ -1160,6 +1168,7 @@ def create_model_config(self) -> ModelConfig:
             mm_processor_cache_type=self.mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=self.mm_encoder_tp_mode,
+            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
             pooler_config=self.pooler_config,
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
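
With the EngineArgs plumbing above, the override is exposed as --mm-encoder-attn-backend on the CLI and as a field on EngineArgs. A usage sketch (the model name is a placeholder):

# CLI (serving):
#   vllm serve Qwen/Qwen2.5-VL-7B-Instruct --mm-encoder-attn-backend FLASH_ATTN
#
# Programmatic, via EngineArgs, which forwards the value into ModelConfig and
# from there into MultiModalConfig:
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder multimodal model
    mm_encoder_attn_backend="FLASH_ATTN",
)
print(engine_args.mm_encoder_attn_backend)  # coerced to _Backend downstream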

vllm/model_executor/models/dots_ocr.py

Lines changed: 17 additions & 2 deletions
@@ -256,6 +256,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -288,7 +289,9 @@ def __init__(
         )
         # Select attention backend
         self.attn_backend = get_vit_attn_backend(
-            self.hidden_size_per_attention_head, torch.get_default_dtype()
+            self.hidden_size_per_attention_head,
+            torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -510,6 +513,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -521,6 +525,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(
@@ -561,6 +566,7 @@ def __init__(
         require_post_norm: bool | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -571,7 +577,9 @@ def __init__(
         head_dim = config.embed_dim // config.num_attention_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -591,6 +599,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{i}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for i in range(num_layers)
             ]
@@ -750,11 +759,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
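
dots_ocr above (and ernie45_vl below) follow the same integration pattern: derive the override once from the multimodal config, then thread attn_backend_override through every vision module down to each get_vit_attn_backend call. A sketch of that pattern with a hypothetical vision tower class, MyVisionTower, standing in for the model-specific ViT:

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig


class MyVisionTower:  # hypothetical, for illustration only
    def __init__(self, attn_backend_override: _Backend | None = None) -> None:
        # A real tower would pass this on to get_vit_attn_backend and its blocks.
        self.attn_backend_override = attn_backend_override


def build_vision_tower(multimodal_config: MultiModalConfig | None) -> MyVisionTower:
    attn_backend_override = (
        multimodal_config.mm_encoder_attn_backend
        if multimodal_config is not None
        else None
    )
    return MyVisionTower(attn_backend_override=attn_backend_override)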

vllm/model_executor/models/ernie45_vl.py

Lines changed: 15 additions & 1 deletion
@@ -164,6 +164,7 @@ def __init__(
         projection_size: int,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -196,6 +197,7 @@ def __init__(
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )

         self.use_upstream_fa = False
@@ -367,6 +369,7 @@ def __init__(
         norm_layer: Callable[[int], nn.Module] | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -382,6 +385,7 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_backend_override=attn_backend_override,
         )

         self.mlp = Ernie4_5_VisionMLP(
@@ -458,6 +462,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         patch_size = vision_config.patch_size
@@ -493,6 +498,7 @@ def __init__(
                     norm_layer=norm_layer,
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -504,7 +510,9 @@ def __init__(
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1327,11 +1335,17 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_model = Ernie4_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
+            attn_backend_override=attn_backend_override,
         )

         self.language_model = Ernie4_5_VLMoeForCausalLM(
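
End to end, the override can also be set from the offline entry point, assuming LLM forwards extra engine keyword arguments to EngineArgs as it does for the other multimodal flags (the model name is a placeholder):

from vllm import LLM

# Run the vision encoder with FLASH_ATTN while the LM attention backend is
# still chosen independently by the usual attention selector.
llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder multimodal model
    mm_encoder_attn_backend="FLASH_ATTN",
)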
