Commit 94a9b10
Merge branch 'main' into titaiwang/fix_modelbuilder_discrepancy
2 parents 65f1ca0 + c7afba2

File tree: 14 files changed, +625 -87 lines

.github/workflows/documentation.yml

Lines changed: 2 additions & 2 deletions
@@ -118,9 +118,9 @@ jobs:
           grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export'
           exit 1
         fi
-        if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string') ]]; then
+        if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache') ]]; then
           echo "Documentation produces warnings."
-          grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string'
+          grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache'
           exit 1
         fi
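The chain of grep -v calls is effectively a warning allowlist; the new MambaCache entry presumably silences the documentation warning introduced by moving the MambaCache import out of module scope (see onnx_diagnostic/helpers/cache_helper.py below). A minimal sketch of the same check with the allowlist kept in a pattern file, so future entries are one-line additions (the file name is illustrative):

# warning_allowlist.txt holds one pattern per line:
# l-plot-tiny-llm-export, Inline emphasis start-string, ..., MambaCache
if grep WARNING doc.txt | grep -v -f warning_allowlist.txt | grep -q .; then
  echo "Documentation produces warnings."
  grep WARNING doc.txt | grep -v -f warning_allowlist.txt
  exit 1
fi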

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ Change Logs
 0.7.11
 ++++++

+* :pr:`224`: support model_id with // to specify a subfolder
+* :pr:`223`: adds task image-to-video
+* :pr:`220`: adds option --ort-logs to display onnxruntime logs when creating the session
 * :pr:`220`: adds a patch for PR `#40791 <https://github.com/huggingface/transformers/pull/40791>`_ in transformers

 0.7.10
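For :pr:`224`, a hedged sketch of the // convention: everything before // is the Hugging Face repo id and everything after it is the subfolder inside that repository (the helper below is a hypothetical illustration, not the library's API):

def split_model_id(model_id: str):
    # Hypothetical helper: "org/name//sub/folder" -> ("org/name", "sub/folder").
    if "//" in model_id:
        repo_id, subfolder = model_id.split("//", maxsplit=1)
        return repo_id, subfolder
    return model_id, None

# The test added in this commit pairs "nvidia/Cosmos-Predict2-2B-Video2World"
# with subfolder="transformer"; under this convention that is the single id
# "nvidia/Cosmos-Predict2-2B-Video2World//transformer".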
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
"""
Compares two ONNX models.
"""

print("-- import onnx")
import onnx

print("-- import onnx.helper")
from onnx.helper import tensor_dtype_to_np_dtype

print("-- import onnxruntime")
import onnxruntime

print("-- import torch")
import torch

print("-- import transformers")
import transformers

print("-- import huggingface_hub")
import huggingface_hub

print("-- import onnx-diagnostic.helper")
from onnx_diagnostic.helpers.helper import flatten_object, string_type, max_diff, string_diff

print("-- import onnx-diagnostic.torch_models.hghub")
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

print("-- done")

model_id = "arnir0/Tiny-LLM"
onnx1 = (
    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op20/"
    "arnir0_Tiny-LLM-custom-default-f16-cuda-op20.onnx"
)
onnx2 = (
    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op21/"
    "arnir0_Tiny-LLM-custom-default-f16-cuda-op21.onnx"
)
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

print(f"-- load {onnx1!r}")
onx1 = onnx.load(onnx1)
print(f"-- load {onnx2!r}")
onx2 = onnx.load(onnx2)

print(f"-- getting inputs for model_id {model_id!r}")
data = get_untrained_model_with_inputs(model_id)
inputs = data["inputs"]
print(f"-- inputs: {string_type(inputs, with_shape=True)}")
flatten_inputs = flatten_object(inputs, drop_keys=True)
print(f"-- flat inputs: {string_type(flatten_inputs, with_shape=True)}")

names = [i.name for i in onx1.graph.input]
itypes = [i.type.tensor_type.elem_type for i in onx1.graph.input]
assert names == [
    i.name for i in onx2.graph.input
], f"Not the same names for both models {names} != {[i.name for i in onx2.graph.input]}"
feeds = {
    n: t.numpy().astype(tensor_dtype_to_np_dtype(itype))
    for n, itype, t in zip(names, itypes, flatten_inputs)
}
print(f"-- feeds: {string_type(feeds, with_shape=True)}")

print(f"-- creating session 1 from {onnx1!r}")
opts = onnxruntime.SessionOptions()
opts.optimized_model_filepath = "debug1_full.onnx"
opts.log_severity_level = 0
opts.log_verbosity_level = 0
sess1 = onnxruntime.InferenceSession(onnx1, opts, providers=providers)
print(f"-- creating session 2 from {onnx2!r}")
opts.optimized_model_filepath = "debug2_full.onnx"
opts.log_severity_level = 0
opts.log_verbosity_level = 0
sess2 = onnxruntime.InferenceSession(onnx2, opts, providers=providers)

print("-- run session1")
expected1 = sess1.run(None, feeds)
print(f"-- got {string_type(expected1, with_shape=True)}")
print("-- run session2")
expected2 = sess2.run(None, feeds)
print(f"-- got {string_type(expected2, with_shape=True)}")

print("-- compute differences")
diff = max_diff(expected1, expected2)
print(f"-- diff={string_diff(diff)}")

def get_names(onx: onnx.ModelProto) -> list[tuple[str, str, str]]:
    "Returns (output name, op_type, node name) for every node output."
    names = []
    for node in onx.graph.node:
        for o in node.output:
            names.append((o, node.op_type, node.name))
    return names


if diff["abs"] > 0.1:
    print("--")
    print("-- import select_model_inputs_outputs")
    from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs

    print("-- looking into intermediate results")
    names1 = get_names(onx1)
    names2 = get_names(onx2)  # was get_names(onx1): compare against the second model
    common = [n for n in names1 if n in (set(names1) & set(names2))]
    print(f"-- {len(common)} names / {len(names1)}-{len(names2)}")
    print(f"-- first names {common[:5]}")
    for name, op_type, op_name in common:
        # truncate both models so they stop at the same intermediate output
        x1 = select_model_inputs_outputs(onx1, [name])
        x2 = select_model_inputs_outputs(onx2, [name])
        s1 = onnxruntime.InferenceSession(x1.SerializeToString(), providers=providers)
        s2 = onnxruntime.InferenceSession(x2.SerializeToString(), providers=providers)
        e1 = s1.run(None, feeds)
        e2 = s2.run(None, feeds)
        diff = max_diff(e1, e2)
        print(
            f"-- name={name!r}: diff={string_diff(diff)} "
            f"- op_type={op_type!r}, op_name={op_name!r}"
        )
        if diff["abs"] > 0.1:
            # dump the optimized truncated models for manual inspection
            opts = onnxruntime.SessionOptions()
            opts.optimized_model_filepath = "debug1.onnx"
            onnxruntime.InferenceSession(x1.SerializeToString(), opts, providers=providers)
            opts.optimized_model_filepath = "debug2.onnx"
            onnxruntime.InferenceSession(x2.SerializeToString(), opts, providers=providers)
            print("--")
            print("-- break here")
            print(f"-- feeds {string_type(feeds, with_shape=True)}")
            print(f"-- e1={string_type(e1, with_shape=True, with_min_max=True)}")
            print(f"-- e2={string_type(e2, with_shape=True, with_min_max=True)}")
            break
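The comparison logic above hinges on max_diff, which returns a dictionary of discrepancy statistics (the script reads its "abs" key), and string_diff, which renders that dictionary for logging. A toy call, assuming both helpers accept plain lists of tensors the way the script passes session outputs:

import torch

from onnx_diagnostic.helpers.helper import max_diff, string_diff

# Two aligned output lists differing by 0.5 in a single element.
d = max_diff([torch.tensor([1.0, 2.0])], [torch.tensor([1.0, 2.5])])
print(d["abs"])        # maximum absolute discrepancy across all outputs
print(string_diff(d))  # compact, human-readable summary of the same dict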
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import unittest
import torch
import transformers
from onnx_diagnostic.ext_test_case import (
    ExtTestCase,
    hide_stdout,
    requires_diffusers,
    requires_torch,
    requires_transformers,
)
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasksImageToVideo(ExtTestCase):
    @hide_stdout()
    @requires_diffusers("0.35")
    @requires_transformers("4.55")
    @requires_torch("2.8.99")
    def test_image_to_video(self):
        kwargs = {
            "_diffusers_version": "0.34.0.dev0",
            "_class_name": "CosmosTransformer3DModel",
            "max_size": [128, 240, 240],
            "text_embed_dim": 128,
            "use_cache": True,
            "in_channels": 3,
            "out_channels": 16,
            "num_layers": 2,
            "model_type": "dia",
            "patch_size": [1, 2, 2],
            "rope_scale": [1.0, 3.0, 3.0],
            "attention_head_dim": 16,
            "mlp_ratio": 0.4,
            "initializer_range": 0.02,
            "num_attention_heads": 16,
            "is_encoder_decoder": True,
            "adaln_lora_dim": 16,
            "concat_padding_mask": True,
            "extra_pos_embed_type": None,
        }
        config = transformers.DiaConfig(**kwargs)
        mid = "nvidia/Cosmos-Predict2-2B-Video2World"
        data = get_untrained_model_with_inputs(
            mid,
            verbose=1,
            add_second_input=True,
            subfolder="transformer",
            config=config,
            inputs_kwargs=dict(image_height=8 * 50, image_width=8 * 80),
        )
        self.assertEqual(data["task"], "image-to-video")
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model(**inputs)
        model(**data["inputs2"])
        with torch.fx.experimental._config.patch(
            backed_size_oblivious=True
        ), torch_export_patches(
            patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
        ):
            torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
            )


if __name__ == "__main__":
    unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 8 additions & 1 deletion
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
@@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser:
         "the onnx exporter should use.",
         default="",
     )
+    parser.add_argument(
+        "--ort-logs",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Enables onnxruntime logging when the session is created",
+    )
     return parser


@@ -601,6 +607,7 @@ def _cmd_validate(argv: List[Any]):
         repeat=args.repeat,
         warmup=args.warmup,
         inputs2=args.inputs2,
+        ort_logs=args.ort_logs,
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
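BooleanOptionalAction gives the new flag a negated twin for free, which is why no explicit --no-ort-logs is declared; a standalone sketch of the behavior (the parser below is illustrative, not the project's real parser):

from argparse import ArgumentParser, BooleanOptionalAction

parser = ArgumentParser()
parser.add_argument("--ort-logs", default=False, action=BooleanOptionalAction)
print(parser.parse_args(["--ort-logs"]).ort_logs)     # True
print(parser.parse_args(["--no-ort-logs"]).ort_logs)  # False
print(parser.parse_args([]).ort_logs)                 # False, the declared default

argparse also converts the dash to an underscore, which is how the value reaches _cmd_validate above as args.ort_logs.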

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 8 additions & 6 deletions
@@ -4,11 +4,6 @@
 import transformers
 import transformers.cache_utils

-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-

 class CacheKeyValue:
     """
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
     )


-def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
     "Creates a ``MambaCache``."
+    # import is moved here because this part is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
     dtype = key_value_pairs[0][0].dtype

     class _config:
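Two details of this hunk deserve a note: the import moves into the function body so that importing cache_helper no longer pays for the slow transformers submodule, and the return annotation becomes the string "MambaCache" so the module still parses when the name is undefined at module level (hence noqa: F821 for flake8's undefined-name check). A reduced sketch of the pattern, with the cache construction elided:

def make_mamba_cache_sketch() -> "MambaCache":  # noqa: F821  string annotation, resolved lazily
    # Deferred import: paid on the first call only, then cached in sys.modules.
    try:
        # recent transformers keeps MambaCache next to the Mamba model
        from transformers.models.mamba.modeling_mamba import MambaCache
    except ImportError:
        # older layouts expose it from transformers.cache_utils
        from transformers.cache_utils import MambaCache
    ...  # the real function builds a MambaCache from the key/value pairs here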

onnx_diagnostic/tasks/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -5,6 +5,8 @@
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -14,7 +16,6 @@
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 )

 __TASKS__ = [
@@ -23,6 +24,8 @@
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -32,7 +35,6 @@
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 ]
