Skip to content

Commit 88b4f89

Browse files
committed
Merge branch 'main' into titaiwang/add_mask_generation
2 parents 2250887 + 27630e3 commit 88b4f89

File tree

7 files changed: +88 -14 lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import unittest
2+
import numpy as np
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
5+
6+
7+
class TestIssues2025(ExtTestCase):
    """Export checks for models taken from upstream PyTorch issue reports."""

    @requires_torch("2.8")
    def test_issue_158786_qwen2vl(self):
        # https://github.com/pytorch/pytorch/issues/158786
        # The model builds (row, column) position ids for every (t, h, w)
        # triple of its input; the test only checks that the exported
        # program still contains the final ``cat`` call.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.spatial_merge_size = 2  # Default

            def _merged_flat(self, grid, h, w):
                # View the (h, w) grid as blocks of
                # spatial_merge_size x spatial_merge_size, move the two block
                # axes next to each other, then flatten to one dimension.
                merge = self.spatial_merge_size
                blocks = grid.reshape(h // merge, merge, w // merge, merge)
                return blocks.permute(0, 2, 1, 3).flatten()

            def forward(self, a):
                chunks = []
                for row in a:
                    t, h, w = row
                    t, h, w = t.item(), h.item(), w.item()
                    # Mark the three data-dependent scalars as usable sizes.
                    for size in (t, h, w):
                        torch._constrain_as_size(size)

                    row_ids = self._merged_flat(
                        torch.arange(h).unsqueeze(1).expand(-1, w), h, w
                    )
                    col_ids = self._merged_flat(
                        torch.arange(w).unsqueeze(0).expand(h, -1), h, w
                    )
                    pairs = torch.stack([row_ids, col_ids], dim=-1)
                    chunks.append(pairs.repeat(t, 1))
                return torch.cat(chunks, dim=0)

        inputs = torch.tensor(np.array([1, 98, 146]).reshape(1, 3))
        exported = torch.export.export(Model(), (inputs,))
        self.assertIn("torch.ops.aten.cat.default", str(exported))


if __name__ == "__main__":
    unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ class TestOnnxExportErrors(ExtTestCase):
2222
def test_pytree_flatten_mamba_cache(self):
2323
import torch
2424
import torch.utils._pytree as py_pytree
25-
from transformers.cache_utils import MambaCache
25+
26+
try:
27+
from transformers.models.mamba.modeling_mamba import MambaCache
28+
except ImportError:
29+
from transformers.cache_utils import MambaCache
2630

2731
class _config:
2832
def __init__(self):

onnx_diagnostic/helpers/_log_helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,17 @@ def open_dataframe(
260260
if isinstance(data, pandas.DataFrame):
261261
return data
262262
if isinstance(data, str):
263-
df = pandas.read_csv(data)
263+
df = pandas.read_csv(data, low_memory=False)
264264
df["RAWFILENAME"] = data
265265
return df
266266
if isinstance(data, tuple):
267267
if not data[-1]:
268-
df = pandas.read_csv(data[2])
268+
df = pandas.read_csv(data[2], low_memory=False)
269269
df["RAWFILENAME"] = data[2]
270270
return df
271271
zf = zipfile.ZipFile(data[-1])
272272
with zf.open(data[2]) as f:
273-
df = pandas.read_csv(f)
273+
df = pandas.read_csv(f, low_memory=False)
274274
df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
275275
zf.close()
276276
return df

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import transformers
55
import transformers.cache_utils
66

7+
try:
8+
from transformers.models.mamba.modeling_mamba import MambaCache
9+
except ImportError:
10+
from transformers.cache_utils import MambaCache
11+
712

813
def flatten_unflatten_for_dynamic_shapes(
914
obj: Any,
@@ -242,10 +247,8 @@ def make_encoder_decoder_cache(
242247
)
243248

244249

245-
def make_mamba_cache(
246-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
247-
) -> transformers.cache_utils.MambaCache:
248-
"Creates a :class:`transformers.cache_utils.MambaCache`."
250+
def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
251+
"Creates a ``MambaCache``."
249252
dtype = key_value_pairs[0][0].dtype
250253

251254
class _config:
@@ -256,7 +259,7 @@ def __init__(self):
256259
self.num_hidden_layers = len(key_value_pairs)
257260
self.dtype = dtype
258261

259-
cache = transformers.cache_utils.MambaCache(
262+
cache = MambaCache(
260263
_config(),
261264
max_batch_size=key_value_pairs[0][0].shape[0],
262265
device=key_value_pairs[0][0].device,
@@ -286,7 +289,7 @@ def __init__(self):
286289

287290
def make_sliding_window_cache(
288291
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
289-
) -> transformers.cache_utils.MambaCache:
292+
) -> transformers.cache_utils.SlidingWindowCache:
290293
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
291294

292295
class _config:

onnx_diagnostic/tasks/text_generation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, Callable, Dict, Optional, Tuple, Union
22
import torch
3-
import transformers
43
from ..helpers.cache_helper import (
54
make_dynamic_cache,
65
make_mamba_cache,
@@ -95,9 +94,14 @@ def get_inputs(
9594
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
9695

9796
if config is not None and config.__class__.__name__ == "FalconMambaConfig":
97+
try:
98+
from transformers.models.mamba.modeling_mamba import MambaCache
99+
except ImportError:
100+
from transformers.cache_utils import MambaCache
101+
98102
assert cls_cache in (
99103
"MambaCache",
100-
transformers.cache_utils.MambaCache,
104+
MambaCache,
101105
), f"Unexpected value for cls_cache={cls_cache} and config={config}"
102106
seq_length_multiple = 8
103107
sequence_length = (

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
import transformers
77
from transformers.cache_utils import (
88
DynamicCache,
9-
MambaCache,
109
EncoderDecoderCache,
1110
SlidingWindowCache,
1211
StaticCache,
1312
)
1413

14+
try:
15+
from transformers.models.mamba.modeling_mamba import MambaCache
16+
except ImportError:
17+
from transformers.cache_utils import MambaCache
18+
1519
from ..helpers import string_type
1620
from .serialization import _lower_name_with_
1721

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
import transformers
44
from transformers.cache_utils import (
55
DynamicCache,
6-
MambaCache,
76
EncoderDecoderCache,
87
SlidingWindowCache,
98
StaticCache,
109
)
10+
11+
try:
12+
from transformers.models.mamba.modeling_mamba import MambaCache
13+
except ImportError:
14+
from transformers.cache_utils import MambaCache
1115
from transformers.modeling_outputs import BaseModelOutput
1216
from ...helpers.cache_helper import make_static_cache
1317
from . import make_serialization_function_for_dataclass

0 commit comments

Comments
 (0)