@@ -256,6 +256,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
 
@@ -288,7 +289,9 @@ def __init__(
         )
         # Select attention backend
         self.attn_backend = get_vit_attn_backend(
-            self.hidden_size_per_attention_head, torch.get_default_dtype()
+            self.hidden_size_per_attention_head,
+            torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False
 
@@ -510,6 +513,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
 
@@ -521,6 +525,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(
@@ -561,6 +566,7 @@ def __init__(
         require_post_norm: bool | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -571,7 +577,9 @@ def __init__(
         head_dim = config.embed_dim // config.num_attention_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -591,6 +599,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{i}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for i in range(num_layers)
             ]
@@ -750,11 +759,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
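For context, the diff threads one optional argument, attn_backend_override, from the top-level model constructor (where it is read from multimodal_config.mm_encoder_attn_backend) down through DotsVisionTransformer and each vision block to every get_vit_attn_backend call, so an explicit backend choice can replace autodetection. The sketch below shows the same plumbing pattern in isolation; the enum values, class names, and the simplified get_vit_attn_backend are stand-ins, not the real vLLM signatures.

from enum import Enum


class _Backend(Enum):  # stand-in for vLLM's attention backend enum
    FLASH_ATTN = "flash_attn"
    XFORMERS = "xformers"


def get_vit_attn_backend(head_size, dtype, attn_backend_override=None):
    # Hypothetical simplification: an explicit override wins over autodetection.
    if attn_backend_override is not None:
        return attn_backend_override
    return _Backend.FLASH_ATTN  # pretend autodetection picked flash attention


class VisionBlock:
    def __init__(self, head_size, dtype, attn_backend_override=None):
        # The override is forwarded untouched; each layer resolves its own backend.
        self.attn_backend = get_vit_attn_backend(
            head_size, dtype, attn_backend_override=attn_backend_override
        )


class VisionTransformer:
    def __init__(self, num_layers, head_size, dtype, attn_backend_override=None):
        self.blocks = [
            VisionBlock(head_size, dtype, attn_backend_override=attn_backend_override)
            for _ in range(num_layers)
        ]


# None keeps the old autodetection path; passing a backend forces it in every block.
vt = VisionTransformer(2, 64, "bf16", attn_backend_override=_Backend.XFORMERS)
assert all(b.attn_backend is _Backend.XFORMERS for b in vt.blocks)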