Commit fc04fa7

more patches
1 parent bf69e8c commit fc04fa7

2 files changed: +43 -5 lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ def _config_reduction(config, task):
     )
     dynamic_shapes = dict(
         hidden_states={0: "hidden_width", 1: "hidden_height"},
-        grid_thw={0: "n_images"},
+        grid_thw={}, # {0: "n_images"}, # TODO: fix
     )
 
     # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
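The change above makes every dimension of grid_thw static while hidden_states keeps two dynamic axes. As a minimal sketch of what such a mapping means to torch.export (the toy module and tensor sizes below are made up, and torch.export.Dim objects are used in place of the plain string names the test relies on):

import torch
from torch.export import Dim, export


class ToyVision(torch.nn.Module):
    # Stand-in for the real model: consumes both inputs so export traces them.
    def forward(self, hidden_states, grid_thw):
        return hidden_states.sum() + grid_thw.to(hidden_states.dtype).sum()


dynamic_shapes = {
    "hidden_states": {0: Dim("hidden_width"), 1: Dim("hidden_height")},
    "grid_thw": {},  # empty mapping: every dimension of grid_thw stays static
}
ep = export(
    ToyVision(),
    (torch.randn(8, 16), torch.tensor([[1, 4, 4]])),
    dynamic_shapes=dynamic_shapes,
)
print(ep)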

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 42 additions & 4 deletions

@@ -2090,11 +2090,44 @@ def prepare_inputs_for_generation(
         return model_inputs
 
 class patched_Qwen2_5_VisionTransformerPretrainedModel:
-    _PATCHES_ = ["get_window_index", "forward"]
+    _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
     _PATCHED_CLASS_ = (
         transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
     )
 
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+        for thw in grid_thw:
+            # PATCHED: avoid unbind
+            t = thw[0]
+            h = thw[1]
+            w = thw[2]
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
     def get_window_index(self, grid_thw):
         window_index: list = []
         # PATCHED
@@ -2104,7 +2137,11 @@ def get_window_index(self, grid_thw):
             self.window_size // self.spatial_merge_size // self.patch_size
         )
 
-        for grid_t, grid_h, grid_w in grid_thw:
+        for _thw in grid_thw:
+            # PATCHED: avoid unbind
+            grid_t = _thw[0]
+            grid_h = _thw[1]
+            grid_w = _thw[2]
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,
                 grid_w // self.spatial_merge_size,
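The two methods above keep the Python loop over grid_thw but replace the tuple unpacking of each row (which goes through Tensor.unbind) with plain indexing. A standalone sketch of that rewrite, using made-up grid values:

import torch

grid_thw = torch.tensor([[1, 8, 8], [1, 4, 12]])  # made-up (t, h, w) rows

# Original pattern: unpacking each row unbinds it into three scalars.
unpacked = [(int(t), int(h), int(w)) for t, h, w in grid_thw]

# Patched pattern: index the row instead of unpacking it.
indexed = []
for thw in grid_thw:
    t, h, w = thw[0], thw[1], thw[2]
    indexed.append((int(t), int(h), int(w)))

assert unpacked == indexed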
@@ -2279,12 +2316,13 @@ def forward(
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
-        query_states, key_states, value_states = (
+        # PATCHED: avoid the use of unbind
+        qkv = (
             self.qkv(hidden_states)
             .reshape(seq_length, 3, self.num_heads, -1)
             .permute(1, 0, 2, 3)
-            .unbind(0)
         )
+        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
         cos, sin = position_embeddings
         query_states, key_states = (
             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
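The last hunk applies the same idea to the fused qkv projection: the trailing .unbind(0) is dropped and the first dimension is indexed instead, which yields the same three tensors. A small self-contained check with assumed shapes (seq_length=6, num_heads=2, head_dim=4):

import torch

seq_length, num_heads, head_dim = 6, 2, 4
qkv = torch.randn(seq_length, 3 * num_heads * head_dim)
stacked = qkv.reshape(seq_length, 3, num_heads, -1).permute(1, 0, 2, 3)

# Original pattern.
q0, k0, v0 = stacked.unbind(0)
# Patched pattern.
q1, k1, v1 = stacked[0], stacked[1], stacked[2]

assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)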
