Skip to content

Commit bd2c354

Browse files
committed
Fixed ros_embed and added multi vision config
Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent 7e21c0f commit bd2c354

File tree

3 files changed

+291
-63
lines changed

3 files changed

+291
-63
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,7 @@ def compile(
12071207
use_onnx_subfunctions=use_onnx_subfunctions,
12081208
**compiler_options,
12091209
)
1210-
1210+
# breakpoint()
12111211
# Custom NPI file options
12121212
if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options:
12131213
compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path)
@@ -1435,6 +1435,15 @@ def kv_offload_generate(
14351435

14361436
vision_inputs_fp16 = {"pixel_values", "image_masks"}
14371437
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})
1438+
pixel_values_shape = list(vision_inputs["pixel_values"].shape)
1439+
breakpoint()
1440+
idx = next(i for i, inner in enumerate(vision_session.allowed_shapes) if (2, pixel_values_shape) in inner)
1441+
1442+
biffer_set = {
1443+
"vision_embeds": np.zeros(vision_session.allowed_shapes[idx][2][1], dtype=np.float16),
1444+
"image_grid_thw": np.zeros(vision_session.allowed_shapes[idx][0][1], dtype=np.int64),
1445+
}
1446+
vision_session.set_buffers(biffer_set)
14381447

14391448
vision_start = perf_counter()
14401449

@@ -1461,6 +1470,16 @@ def kv_offload_generate(
14611470
vision_session.deactivate()
14621471
lang_session.activate()
14631472

1473+
vision_outputs["vision_embeds"] = np.pad(
1474+
vision_outputs["vision_embeds"],
1475+
pad_width=(
1476+
(0, 0),
1477+
(0, lang_session.allowed_shapes[0][1][1][1] - vision_session.allowed_shapes[idx][2][1][1]),
1478+
(0, 0),
1479+
), # pad axis=1 only
1480+
mode="constant",
1481+
constant_values=0,
1482+
)
14641483
lang_session.set_buffers(vision_outputs)
14651484

14661485
if self.comp_ctx_lengths_prefill is not None:

QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 119 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
6868
Returns:
6969
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
7070
"""
71-
71+
# breakpoint()
7272
cos = cos.unsqueeze(unsqueeze_dim)
7373
sin = sin.unsqueeze(unsqueeze_dim)
7474
q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -144,8 +144,43 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
144144

145145

146146
class QEffQwen3VLVisionModel(Qwen3VLVisionModel):
147+
# def rot_pos_emb(self, grid_thw):
148+
# pos_ids = []
149+
150+
# bs, t, h, w = grid_thw.shape
151+
152+
# hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
153+
# hpos_ids = hpos_ids.reshape(
154+
# h // self.spatial_merge_size,
155+
# self.spatial_merge_size,
156+
# w // self.spatial_merge_size,
157+
# self.spatial_merge_size,
158+
# )
159+
# hpos_ids = hpos_ids.permute(0, 2, 1, 3)
160+
# hpos_ids = hpos_ids.flatten()
161+
162+
# wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
163+
# wpos_ids = wpos_ids.reshape(
164+
# h // self.spatial_merge_size,
165+
# self.spatial_merge_size,
166+
# w // self.spatial_merge_size,
167+
# self.spatial_merge_size,
168+
# )
169+
# wpos_ids = wpos_ids.permute(0, 2, 1, 3)
170+
# wpos_ids = wpos_ids.flatten()
171+
# pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
172+
# pos_ids = torch.cat(pos_ids, dim=0)
173+
174+
# x_expanded = pos_ids.unsqueeze(0)
175+
# x_expanded = x_expanded.expand(bs, -1, -1)
176+
# pos_ids = x_expanded.reshape(-1, pos_ids.size(1))
177+
178+
# max_grid_size = max(grid_thw.shape)
179+
# rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
180+
# rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
181+
# return rotary_pos_emb
147182
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
148-
# ##breakpoint()
183+
# breakpoint()
149184
merge_size = self.spatial_merge_size
150185

151186
# max_hw = int(grid_thw[:, 1:].max().item())
@@ -160,7 +195,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
160195
total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
161196
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
162197

163-
offset = 0
198+
# offset = 0
164199
# breakpoint()
165200

166201
# for bs,num_frames, height, width in grid_thw:
@@ -182,33 +217,39 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
182217

183218
if num_frames > 1:
184219
coords = coords.repeat(num_frames, 1)
185-
186-
num_tokens = coords.shape[0]
187-
pos_ids[offset : offset + num_tokens] = coords
188-
offset += num_tokens
189-
220+
# breakpoint()
221+
# num_tokens = len(coords)
222+
# pos_ids[offset : offset + num_tokens] = coords
223+
# offset += num_tokens
224+
# breakpoint()
225+
# pos_ids = torch.cat([pos_ids, coords], dim=-1)
226+
pos_ids = coords
190227
embeddings = freq_table[pos_ids] # lookup rotary embeddings
191228
embeddings = embeddings.flatten(1)
192229
return embeddings
193230

194231
def fast_pos_embed_interpolate(self, grid_thw):
195-
# breakpoint()
196-
# gridbs, grid_ts, grid_hs, grid_ws = grid_thw.shape
197232
bs, t, h, w = grid_thw.shape
198233
# grid_ts = torch.tensor([grid_ts], device=grid_thw.device)
199234
# grid_hs = torch.tensor([grid_hs], device=grid_thw.device)
200235
# grid_ws = torch.tensor([grid_ws], device=grid_thw.device)
201-
idx_list = [[] for _ in range(4)]
202-
weight_list = [[] for _ in range(4)]
236+
# idx_list = [[] for _ in range(4)]
237+
# weight_list = [[] for _ in range(4)]
203238
# t,h,w = grid_ts[0],grid_hs[0],grid_ws[0]
204239
# for t, h, w in zip(grid_ts, grid_hs, grid_ws):
205240
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
206241
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
207242

208243
h_idxs_floor = h_idxs.int()
209244
w_idxs_floor = w_idxs.int()
210-
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
211-
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
245+
# h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
246+
# w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
247+
# h_idxs_ceil = torch.clip(h_idxs.int(),0,self.num_grid_per_side - 1)
248+
# w_idxs_ceil = torch.clip(w_idxs.int(),0,self.num_grid_per_side - 1)
249+
max_t = torch.tensor(self.num_grid_per_side - 1, device=h_idxs.device)
250+
251+
h_idxs_ceil = torch.minimum(h_idxs_floor + 1, max_t) # working
252+
w_idxs_ceil = torch.minimum(w_idxs_floor + 1, max_t)
212253

213254
dh = h_idxs - h_idxs_floor
214255
dw = w_idxs - w_idxs_floor
@@ -230,12 +271,10 @@ def fast_pos_embed_interpolate(self, grid_thw):
230271
(dh[None].T * dw[None]).flatten(),
231272
]
232273

233-
for i in range(4):
234-
idx_list[i].extend(indices[i].tolist())
235-
weight_list[i].extend(weights[i].tolist())
236-
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
237-
weight_tensor = torch.tensor(
238-
weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
274+
idx_tensor = torch.stack(indices, dim=0).to(dtype=torch.long, device=self.pos_embed.weight.device) # [4, h*w]
275+
276+
weight_tensor = torch.stack(weights, dim=0).to(
277+
dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
239278
)
240279
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
241280
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
@@ -244,9 +283,7 @@ def fast_pos_embed_interpolate(self, grid_thw):
244283

245284
patch_pos_embeds_permute = []
246285
merge_size = self.config.spatial_merge_size
247-
# breakpoint()
248286
pos_embed = patch_pos_embeds[0]
249-
# for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
250287
pos_embed = pos_embed.repeat(t, 1)
251288

252289
pos_embed = (
@@ -262,20 +299,18 @@ def fast_pos_embed_interpolate(self, grid_thw):
262299
return patch_pos_embeds
263300

264301
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
265-
# ##breakpoint()
266302
hidden_states = self.patch_embed(hidden_states)
267303
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
268-
# breakpoint()
304+
269305
hidden_states = hidden_states + pos_embeds
270-
# breakpoint()
306+
271307
rotary_pos_emb = self.rot_pos_emb(grid_thw)
272308

273309
seq_len, _ = hidden_states.size()
274310
hidden_states = hidden_states.reshape(seq_len, -1)
275311
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
276312
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
277313
position_embeddings = (emb.cos(), emb.sin())
278-
# ##breakpoint()
279314
bs, t, h, w = grid_thw.shape
280315

281316
t = torch.arange(t, t + 1).squeeze().expand(bs)
@@ -286,7 +321,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
286321
dim=0,
287322
dtype=torch.int32,
288323
)
289-
# ##breakpoint()
290324
cu_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), cu_seqlens])
291325

292326
deepstack_feature_lists = []
@@ -301,9 +335,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
301335
hidden_states
302336
)
303337
deepstack_feature_lists.append(deepstack_feature)
304-
# ##breakpoint()
305338
hidden_states = self.merger(hidden_states)
306-
# ##breakpoint()
307339
return hidden_states, deepstack_feature_lists
308340

309341

@@ -322,7 +354,7 @@ def forward(
322354
rotary_pos_emb: Optional[torch.Tensor] = None,
323355
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
324356
) -> torch.Tensor:
325-
# ##breakpoint()
357+
# breakpoint()
326358
seq_length = hidden_states.shape[0]
327359
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
328360
if position_embeddings is None:
@@ -390,7 +422,7 @@ def eager_attention_forward(
390422
past_key_value: Optional[Cache] = None,
391423
**kwargs,
392424
):
393-
# ##breakpoint()
425+
# breakpoint()
394426
key_states = repeat_kv(key, module.num_key_value_groups)
395427
value_states = repeat_kv(value, module.num_key_value_groups)
396428

@@ -425,7 +457,7 @@ def forward(
425457
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
426458
**kwargs,
427459
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
428-
##breakpoint()
460+
# breakpoint()
429461
input_shape = hidden_states.shape[:-1]
430462
hidden_shape = (*input_shape, -1, self.head_dim)
431463
bsz, q_len, _ = hidden_states.size()
@@ -898,8 +930,9 @@ def get_specializations(
898930
prefill_seq_len: int,
899931
ctx_len: int,
900932
img_size: None,
901-
height: int = None,
902-
width: int = None,
933+
# height: int = None,
934+
# width: int = None,
935+
dimensions: List = None,
903936
num_frames: int = 1,
904937
kv_offload: bool = False,
905938
continuous_batching: bool = False,
@@ -910,12 +943,13 @@ def get_specializations(
910943
# ##breakpoint()
911944
comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None)
912945
comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None)
913-
if height is None or width is None:
946+
if dimensions is None:
914947
height = 1365
915948
width = 2048
916949
logger.warning(
917950
"Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
918951
)
952+
dimensions = [[height, width]]
919953
prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
920954
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
921955
channel = 3
@@ -972,30 +1006,53 @@ def smart_resize(
9721006
w_bar = ceil_by_factor(width * beta, factor)
9731007
return h_bar, w_bar
9741008

975-
resized_height, resized_width = smart_resize(height=height, width=width)
976-
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
977-
grid_height = grid_h * grid_w
978-
grid_width = patch_size * patch_size * temporal_patch_size * channel
979-
vision_size = grid_height // 4
980-
vision_size = vision_size * num_frames
981-
grid_height = grid_height * batch_size
1009+
# resized_height, resized_width = smart_resize(height=height, width=width)
1010+
# grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
1011+
# grid_height = grid_h * grid_w
1012+
# grid_width = patch_size * patch_size * temporal_patch_size * channel
1013+
# vision_size = grid_height // 4
1014+
# vision_size = vision_size * num_frames
1015+
# grid_height = grid_height * batch_size
1016+
# # breakpoint()
1017+
# # vision_size = 176
1018+
# # grid_height = 704
1019+
# # grid_width = 1536
1020+
# # grid_h = 22
1021+
# # grid_w = 32
1022+
# # breakpoint()
1023+
vision = []
1024+
max_vision_size = 0
9821025
# breakpoint()
983-
# vision_size = 176
984-
# grid_height = 704
985-
# grid_width = 1536
986-
# grid_h = 22
987-
# grid_w = 32
988-
# breakpoint()
989-
vision = [
990-
{
991-
"batch_size": batch_size,
992-
"vision_size": vision_size,
993-
"grid_height": grid_height,
994-
"grid_width": grid_width,
995-
"grid_h": grid_h,
996-
"grid_w": grid_w,
997-
}
998-
]
1026+
for dimension in dimensions:
1027+
resized_height, resized_width = smart_resize(height=dimension[0], width=dimension[1])
1028+
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
1029+
grid_height = grid_h * grid_w
1030+
grid_width = patch_size * patch_size * temporal_patch_size * channel
1031+
vision_size = grid_height // 4
1032+
vision_size = vision_size * num_frames
1033+
grid_height = grid_height * batch_size
1034+
1035+
max_vision_size = max(max_vision_size, vision_size)
1036+
# vision = [
1037+
# {
1038+
# "batch_size": batch_size,
1039+
# "vision_size": max_vision_size,
1040+
# "grid_height": grid_height,
1041+
# "grid_width": grid_width,
1042+
# "grid_h": grid_h,
1043+
# "grid_w": grid_w,
1044+
# }
1045+
# ]
1046+
vision.append(
1047+
{
1048+
"batch_size": batch_size,
1049+
"vision_size": vision_size,
1050+
"grid_height": grid_height,
1051+
"grid_width": grid_width,
1052+
"grid_h": grid_h,
1053+
"grid_w": grid_w,
1054+
}
1055+
)
9991056
# ##breakpoint()
10001057

10011058
if comp_ctx_lengths_prefill is not None:
@@ -1006,7 +1063,7 @@ def smart_resize(
10061063
"batch_size": 1 if continuous_batching else batch_size,
10071064
"seq_len": prefill_seq_len,
10081065
"ctx_len": ctx_len,
1009-
"vision_size": vision_size,
1066+
"vision_size": max_vision_size,
10101067
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
10111068
"vision_batch_size": batch_size,
10121069
}
@@ -1025,7 +1082,7 @@ def smart_resize(
10251082
"batch_size": full_batch_size if continuous_batching else batch_size,
10261083
"seq_len": "1",
10271084
"ctx_len": ctx_len,
1028-
"vision_size": vision_size,
1085+
"vision_size": max_vision_size,
10291086
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
10301087
"vision_batch_size": batch_size,
10311088
}
@@ -1041,7 +1098,7 @@ def smart_resize(
10411098
"batch_size": 1 if continuous_batching else batch_size,
10421099
"seq_len": prefill_seq_len,
10431100
"ctx_len": ctx_len,
1044-
"vision_size": vision_size,
1101+
"vision_size": max_vision_size,
10451102
"vision_batch_size": batch_size,
10461103
}
10471104

@@ -1056,7 +1113,7 @@ def smart_resize(
10561113
"batch_size": full_batch_size if continuous_batching else batch_size,
10571114
"seq_len": 1,
10581115
"ctx_len": ctx_len,
1059-
"vision_size": vision_size,
1116+
"vision_size": max_vision_size,
10601117
"vision_batch_size": batch_size,
10611118
}
10621119

0 commit comments

Comments
 (0)