
Commit cce6f95

Initial support for Qwen2.5-VL
1 parent d0413b0 commit cce6f95

8 files changed: +222 -48 lines changed

exllamav2/architecture.py

Lines changed: 33 additions & 16 deletions
@@ -356,7 +356,7 @@ class Params:
 
         # Qwen2-VL (2, 2.5)
 
-        if arch_string == "Qwen2VLForConditionalGeneration":
+        if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
             arch_recognized = True
             self.lm.layer_keys += \
                 layer_keys_llama_norms + \
@@ -368,27 +368,44 @@ class Params:
             self.lm.mrope = True
             self.lm.rope_freq_half = True
 
-            read_config["vision_config"].update({"model_type": "qwen2"})
             self.vt_prefix = "visual."
-            self.vt.keys.update({
-                "fused_qkv": ".attn.qkv",
-                "attn_o": ".attn.proj",
-                "mlp_gate": None,
-                "mlp_up": ".mlp.fc1",
-                "mlp_down": ".mlp.fc2",
-                "norm_1": ".norm1",
-                "norm_2": ".norm2",
-                "layers": "blocks",
-                "patch_conv": "patch_embed.proj",
-            })
-            self.vt.mlp_gate = False
+            if arch_string == "Qwen2VLForConditionalGeneration":
+                read_config["vision_config"].update({"model_type": "qwen2"})
+                self.vt.keys.update({
+                    "fused_qkv": ".attn.qkv",
+                    "attn_o": ".attn.proj",
+                    "mlp_gate": None,
+                    "mlp_up": ".mlp.fc1",
+                    "mlp_down": ".mlp.fc2",
+                    "norm_1": ".norm1",
+                    "norm_2": ".norm2",
+                    "layers": "blocks",
+                    "patch_conv": "patch_embed.proj",
+                })
+                self.vt.mlp_gate = False
+                self.vt.mlp_act_func = "quickgelu"
+                self.vt.norm = "layernorm"
+            elif arch_string == "Qwen2_5_VLForConditionalGeneration":
+                read_config["vision_config"].update({"model_type": "qwen2.5"})
+                self.vt.keys.update({
+                    "fused_qkv": ".attn.qkv",
+                    "attn_o": ".attn.proj",
+                    "mlp_gate": ".mlp.gate_proj",
+                    "mlp_up": ".mlp.up_proj",
+                    "mlp_down": ".mlp.down_proj",
+                    "norm_1": ".norm1",
+                    "norm_2": ".norm2",
+                    "layers": "blocks",
+                    "patch_conv": "patch_embed.proj",
+                })
+                self.vt.mlp_gate = True
+                self.vt.mlp_act_func = "silu"
+                self.vt.norm = "rmsnorm"
             self.vt.mlp_bias = True
             self.vt.attention_bias_qkv = True
             self.vt.attention_bias_o = True
             self.vt.vision_input_norm = False
             self.vt.vision_conv3d = True
-            self.vt.mlp_act_func = "quickgelu"
-            self.vt.norm = "layernorm"
 
             self.mmp_prefix = "visual.merger."
             self.mmp.keys.update({
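
The two key maps above encode the main vision-tower difference between the generations: Qwen2-VL uses a plain two-layer MLP (fc1/fc2) with QuickGELU and LayerNorm, while Qwen2.5-VL uses a gated SiLU MLP (gate_proj/up_proj/down_proj) with RMSNorm, hence the different mlp_gate, mlp_act_func and norm settings. A rough standalone sketch of the two block shapes, for illustration only (the module names below are not exllamav2 classes):

import torch
import torch.nn as nn

class QuickGeluMLP(nn.Module):
    # Qwen2-VL style vision MLP: fc1 -> QuickGELU -> fc2
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden, bias = True)
        self.fc2 = nn.Linear(hidden, dim, bias = True)
    def forward(self, x):
        h = self.fc1(x)
        return self.fc2(h * torch.sigmoid(1.702 * h))  # QuickGELU approximation

class GatedSiluMLP(nn.Module):
    # Qwen2.5-VL style vision MLP: silu(gate_proj(x)) * up_proj(x) -> down_proj
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias = True)
        self.up_proj = nn.Linear(dim, hidden, bias = True)
        self.down_proj = nn.Linear(hidden, dim, bias = True)
    def forward(self, x):
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))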

exllamav2/attn.py

Lines changed: 32 additions & 12 deletions
@@ -41,11 +41,11 @@
         print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")
 
     if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
-        from flash_attn import flash_attn_func
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
         has_flash_attn = True
 
     if [2, 5, 7] <= flash_attn_ver:
-        from flash_attn import flash_attn_func, flash_attn_with_kvcache
+        from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
         # import flash_attn_2_cuda as flash_attn_cuda
 
         signature = list(inspect.signature(flash_attn_func).parameters)
@@ -882,7 +882,9 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
            k_states = k_states[:, :, -self.sliding_window:, :]
            v_states = v_states[:, :, -self.sliding_window:, :]
 
-        if attn_params.is_causal():
+        if self.layer_idx in attn_params.block_diag_layers:
+            attn_mask_lr = attn_params.get_block_diag_mask(q_states.device)
+        elif attn_params.is_causal():
             attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
         else:
             attn_mask_lr = attn_params.get_attn_mask(q_states.device)
@@ -904,7 +906,9 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
            attn_weights = torch.matmul(q_states, k_states)
 
            attn_weights *= self.scaling
-           if causal:
+           if self.layer_idx in attn_params.block_diag_layers:
+               attn_mask = attn_params.get_block_diag_mask(attn_weights.device)
+           elif causal:
                attn_mask = attn_params.get_attn_mask(attn_weights.device)
 
            if cfg.attn_logit_softcapping:
@@ -939,14 +943,30 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para
        if has_flash_attn_with_softcap:
            flash_kwargs["softcap"] = cfg.attn_logit_softcapping
 
-       attn_output = flash_attn_func(
-           q_states,
-           k_states,
-           v_states,
-           causal = causal,
-           softmax_scale = self.scaling,
-           **flash_kwargs
-       )
+       if self.layer_idx in attn_params.block_diag_layers:
+           q_states = q_states.flatten(start_dim = 0, end_dim = 1)
+           k_states = k_states.flatten(start_dim = 0, end_dim = 1)
+           v_states = v_states.flatten(start_dim = 0, end_dim = 1)
+           max_seqlen = attn_params.get_cu_seqlens_max()
+           cu_seqlens = attn_params.get_cu_seqlens(self.device_idx)
+           attn_output = flash_attn_varlen_func(
+               q_states,
+               k_states,
+               v_states,
+               cu_seqlens,
+               cu_seqlens,
+               max_seqlen,
+               max_seqlen
+           )
+       else:
+           attn_output = flash_attn_func(
+               q_states,
+               k_states,
+               v_states,
+               causal = causal,
+               softmax_scale = self.scaling,
+               **flash_kwargs
+           )
        attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim))
        return attn_output
 
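
The new branch flattens the batch and calls flash_attn_varlen_func so that each span delimited by cu_seqlens attends only to itself, i.e. block-diagonal attention over the packed windows. A minimal reference of what that computes, written with plain PyTorch SDPA instead of flash-attn (tensor shapes and names here are assumptions for illustration):

import torch
import torch.nn.functional as F

def block_diag_attention(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens: (num_windows + 1,)
    out = torch.empty_like(q)
    for a, b in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # each window [a, b) attends only within itself
        qw = q[a:b].transpose(0, 1)  # (num_heads, window_len, head_dim)
        kw = k[a:b].transpose(0, 1)
        vw = v[a:b].transpose(0, 1)
        out[a:b] = F.scaled_dot_product_attention(qw, kw, vw).transpose(0, 1)
    return out

q = k = v = torch.randn(10, 4, 8)
cu = torch.tensor([0, 3, 7, 10])                 # three windows of lengths 3, 4 and 3
print(block_diag_attention(q, k, v, cu).shape)   # torch.Size([10, 4, 8])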

exllamav2/attn_params.py

Lines changed: 34 additions & 0 deletions
@@ -21,6 +21,10 @@ class Params:
     alt_rope_embed_dict: dict | None
     rope_offsets: torch.Tensor | None
     non_causal_attn: bool
+    block_diag_layers: set
+    block_diag_mask: torch.Tensor | None
+    cu_seqlens: torch.Tensor | None
+    cu_seqlens_max: int | None
 
     def __init__(
         self,
@@ -66,6 +70,11 @@ def __init__(
         self.past_len_tp = None
         self.paged = paged
 
+        self.block_diag_layers = set()
+        self.block_diag_mask = None
+        self.cu_seqlens = None
+        self.cu_seqlens_max = None
+
     def is_causal(self) -> bool:
         return self.input_mask is None
 
@@ -164,6 +173,31 @@ def get_rope_offsets(self, device_idx: int) -> torch.Tensor | None:
            self.rope_offsets = safe_move_tensor(self.rope_offsets, device_idx, non_blocking = True)
        return self.rope_offsets
 
+    def get_cu_seqlens(self, device: int) -> torch.Tensor | None:
+        if self.cu_seqlens is None:
+            return None
+        if self.cu_seqlens.device.index != device:
+            self.cu_seqlens = safe_move_tensor(self.cu_seqlens, device, non_blocking = True)
+        return self.cu_seqlens
+
+    def get_cu_seqlens_max(self) -> torch.Tensor | None:
+        assert self.cu_seqlens is not None
+        if self.cu_seqlens_max is not None:
+            return self.cu_seqlens_max
+        self.cu_seqlens_max = (self.cu_seqlens[1:] - self.cu_seqlens[:-1]).max().item()
+        return self.cu_seqlens_max
+
+    def get_block_diag_mask(self, device: int) -> torch.Tensor | None:
+        if self.block_diag_mask is None:
+            csl = self.get_cu_seqlens(device)
+            if csl is None:
+                return None
+            positions = torch.arange(csl[-1], device = csl.device)
+            labels = torch.searchsorted(csl[1:], positions, right = True)
+            self.block_diag_mask = labels.unsqueeze(0) == labels.unsqueeze(1).repeat(self.batch_size)
+        if self.block_diag_mask.device.index != device:
+            self.block_diag_mask = safe_move_tensor(self.block_diag_mask, device, non_blocking = True)
+        return self.block_diag_mask
 
 
 class PagedParams(Params):
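
get_block_diag_mask builds the equivalent mask for the non-flash paths: torch.searchsorted labels every token position with the index of the window it falls into, and comparing the labels pairwise gives a boolean block-diagonal mask. A small standalone example of that construction (independent of the Params class above):

import torch

cu_seqlens = torch.tensor([0, 3, 7])                 # two windows: lengths 3 and 4
positions = torch.arange(cu_seqlens[-1])             # positions 0 .. 6
labels = torch.searchsorted(cu_seqlens[1:], positions, right = True)
mask = labels.unsqueeze(0) == labels.unsqueeze(1)    # True only within a window
print(labels)       # tensor([0, 0, 0, 1, 1, 1, 1])
print(mask.int())   # a 3x3 and a 4x4 block of ones on the diagonal, zeros elsewhere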

exllamav2/config.py

Lines changed: 25 additions & 7 deletions
@@ -135,6 +135,7 @@ class ExLlamaV2Config:
     vision_num_key_value_groups: int | None
     vision_hidden_size: int | None
     vision_intermediate_size: int | None
+    vision_merger_intermediate_size: int | None
     vision_hidden_act: str | None
     vision_rope_theta: float | None
     vision_feature_layer: int | None
@@ -152,6 +153,8 @@ class ExLlamaV2Config:
     vision_max_pixels: int | None
     vision_temporal_patch_size: int | None
     vision_max_size: int | None
+    vision_fullatt_block_indexes: list | None
+    vision_window_size: int | None
 
     # Deprecated fields, kept for compatibiltiy
 
@@ -478,6 +481,8 @@ def check_keys(archparams, prefix):
 
        # TODO: Cleanup & refactor
 
+       self.vision_fullatt_block_indexes = None
+
        if self.vision_model_type is None:
            pass
 
@@ -495,6 +500,7 @@ def check_keys(archparams, prefix):
            self.vision_feature_layer = read(read_config, int, ["vision_feature_layer"], no_default)
            self.vision_num_layers = read(read_config, int, ["vision_config->num_hidden_layers"], 24)
            self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], self.hidden_size)
+           self.vision_merger_intermediate_size = self.vision_intermediate_size
 
            image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
            assert image_processor_type == "PixtralImageProcessor", \
@@ -511,10 +517,27 @@ def check_keys(archparams, prefix):
            self.vision_spatial_merge_size = 1
            self.vision_max_size = 16384
 
-       elif self.vision_model_type == "qwen2":
+       elif self.vision_model_type in ["qwen2", "qwen2.5"]:
+           image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
+           if self.vision_model_type == "qwen2":
+               self.vision_hidden_size = read(read_config, int, ["vision_config->embed_dim"], no_default)
+               mlp_ratio = read(read_config, int, ["vision_config->mlp_ratio"], None)
+               self.vision_intermediate_size = self.vision_hidden_size * mlp_ratio
+               self.vision_merger_intermediate_size = self.vision_intermediate_size
+               assert image_processor_type == "Qwen2VLImageProcessor", \
+                   f"Wrong image processor type: {image_processor_type}"
+               self.vision_window_size = None
+           elif self.vision_model_type == "qwen2.5":
+               self.vision_hidden_size = read(read_config, int, ["vision_config->hidden_size"], no_default)
+               self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], no_default)
+               self.vision_fullatt_block_indexes = read(read_config, list, ["vision_config->fullatt_block_indexes", None])
+               self.vision_window_size = read(read_config, int, ["vision_config->window_size", None])
+               assert image_processor_type == "Qwen2_5_VLImageProcessor", \
+                   f"Wrong image processor type: {image_processor_type}"
+               self.vision_merger_intermediate_size = 5120  # TODO: This doesn't seem to appear in the config anywhere?
+
            self.vision_num_attention_heads = read(read_config, int, ["vision_config->num_heads"], no_default)
            self.vision_num_key_value_heads = self.vision_num_attention_heads
-           self.vision_hidden_size = read(read_config, int, ["vision_config->embed_dim"], no_default)
            self.vision_head_dim = self.vision_hidden_size // self.vision_num_attention_heads
            self.vision_num_key_value_groups = 1
            self.vision_hidden_act = "quickgelu"
@@ -523,12 +546,7 @@ def check_keys(archparams, prefix):
            patch_size = read(read_config, int, ["vision_config->patch_size"], no_default)
            self.vision_rope_theta = read(read_config, int, ["vision_config->rope_theta"], 10000.0)
            self.vision_num_layers = read(read_config, int, ["vision_config->depth"], no_default)
-           mlp_ratio = read(read_config, int, ["vision_config->mlp_ratio"], no_default)
-           self.vision_intermediate_size = self.vision_hidden_size * mlp_ratio
 
-           image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
-           assert image_processor_type == "Qwen2VLImageProcessor", \
-               f"Wrong image processor type: {image_processor_type}"
            self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default)
            self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default)
            assert read(read_prep_config, int, ["patch_size"], no_default) == patch_size, \
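
The "->" in keys such as "vision_config->window_size" addresses nested entries of the model's config.json. A rough standalone illustration of that addressing scheme (this is not the actual exllamav2 read() helper, and the sample values below are assumptions used only for illustration):

def read_nested(cfg: dict, key: str, default = None):
    # "vision_config->window_size" walks cfg["vision_config"]["window_size"]
    node = cfg
    for part in key.split("->"):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

example = {"vision_config": {"window_size": 112, "fullatt_block_indexes": [7, 15, 23, 31]}}
print(read_nested(example, "vision_config->window_size"))            # 112
print(read_nested(example, "vision_config->fullatt_block_indexes"))  # [7, 15, 23, 31]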

exllamav2/mlp.py

Lines changed: 4 additions & 3 deletions
@@ -51,6 +51,7 @@ def __init__(
         out_features: int | None = None,
         interm_features: int | None = None,
         merge: int | None = None,
+        pad32: bool = True,
     ):
         super().__init__(model, key, archparams)
         cfg = self.model.config
@@ -98,8 +99,8 @@ def __init__(
         self.pre_layernorm = None
         self.post_layernorm = None
 
-        self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
-        self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth)
+        self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c, pad32 = pad32)
+        self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth, pad32 = pad32)
 
         self.submodules = [self.up_proj,
                            self.down_proj]
@@ -109,7 +110,7 @@ def __init__(
             self.submodules += [self.post_layernorm]
 
         if ap.mlp_gate:
-            self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
+            self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b, pad32 = pad32)
             self.submodules += [self.gate_proj]
         else:
             self.gate_proj = None

exllamav2/vlm/processor/pixtral.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def preprocess(
 
     image = image.transpose(2, 0, 1)
     image = torch.from_numpy(image).half()
-    return image, new_size
+    return image, new_size, None, None
 
 def postprocess(
     model: ExLlamaV2,

exllamav2/vlm/processor/qwen2.py

Lines changed: 50 additions & 2 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import torch
+import torch.nn.functional as F
 import numpy as np
 from PIL import Image
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
@@ -86,7 +87,7 @@ def preprocess(
 
     if mode == "image":
         image = torch.from_numpy(flatten_patches).half()
-        return image, new_size
+        return image, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
     else:
         video = torch.from_numpy(flatten_patches).half()
         return video, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
@@ -149,4 +150,51 @@ def position_embeddings(
     cos = cos.unsqueeze(1).repeat(1, 1, 2).contiguous()
     sin = sin.unsqueeze(1).repeat(1, 1, 2).contiguous()
 
-    return sin, cos
+    return sin, cos
+
+
+def get_window_index(grid_thw, config: ExLlamaV2Config):
+
+    window_index: list = []
+    cu_window_seqlens: list = [0]
+    window_index_id = 0
+    vit_merger_window_size = (
+        config.vision_window_size //
+        config.vision_spatial_merge_size //
+        config.vision_patch_size["height"]
+    )
+
+    for grid_t, grid_h, grid_w in grid_thw:
+        llm_grid_h, llm_grid_w = (
+            grid_h // config.vision_spatial_merge_size,
+            grid_w // config.vision_spatial_merge_size,
+        )
+        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+        index_padded = index_padded.reshape(
+            grid_t,
+            num_windows_h,
+            vit_merger_window_size,
+            num_windows_w,
+            vit_merger_window_size,
+        )
+        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+            grid_t,
+            num_windows_h * num_windows_w,
+            vit_merger_window_size,
+            vit_merger_window_size,
+        )
+        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+        index_padded = index_padded.reshape(-1)
+        index_new = index_padded[index_padded != -100]
+        window_index.append(index_new + window_index_id)
+        cu_seqlens_tmp = seqlens.cumsum(0) * config.vision_spatial_merge_size**2 + cu_window_seqlens[-1]
+        cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+        window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+
+    window_index = torch.cat(window_index, dim = 0)
+    return window_index, cu_window_seqlens
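
A rough usage sketch of get_window_index as defined above, with a stand-in namespace instead of a real ExLlamaV2Config (the attribute values are assumptions chosen only to make the arithmetic concrete; in practice they are read from the model's config files):

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    vision_window_size = 112,            # hypothetical values, illustration only
    vision_spatial_merge_size = 2,
    vision_patch_size = {"height": 14},
)
# vit_merger_window_size = 112 // 2 // 14 = 4 merged positions per window side

grid_thw = torch.tensor([[1, 16, 16]])   # one image of 16 x 16 patches
window_index, cu_window_seqlens = get_window_index(grid_thw, cfg)
print(window_index.shape)      # torch.Size([64]): the 8 x 8 merged positions, reordered window by window
print(len(cu_window_seqlens))  # 1 + 3 * 3 entries: cumulative patch counts per (possibly empty) window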
