
Commit 9a3d0cf

fix dynamo

2 parents 116b59c + 283b2cd

File tree: 7 files changed, +237 −2 lines

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ Change Logs
 ++++++
 
 * :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
+* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
 * :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
 * :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
 

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
"""
Dynamic Shapes and Broadcasting
===============================

:func:`torch.export.export` makes stricter assumptions on dynamic shapes
than the generic case requires. Let's consider two tensors with only one
dimension each. ``x * y`` allows four configurations:

* ``shape(x) = (1,)`` and ``shape(y) = (1,)``
* ``shape(x) = (1,)`` and ``shape(y) = (p,)``
* ``shape(x) = (q,)`` and ``shape(y) = (1,)``
* ``shape(x) = (p,)`` and ``shape(y) = (p,)``

The expected shape ``shape(x * y)`` is ``(max(p, q),)``.

Simple Case
+++++++++++

"""

import torch
from torch.fx import Tracer
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from onnx_diagnostic.torch_export_patches import torch_export_patches
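
# %%
# As a quick eager-mode check (an illustrative sketch, independent of the
# rest of the example), all four configurations above broadcast to
# ``(max(p, q),)``:

for sx, sy in [((1,), (1,)), ((1,), (3,)), ((4,), (1,)), ((3,), (3,))]:
    t = torch.zeros(sx) * torch.zeros(sy)
    print(f"shape(x)={sx}, shape(y)={sy} -> shape(x * y)={tuple(t.shape)}")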


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x * y


Dim = torch.export.Dim

ep = torch.export.export(
    Model(),
    (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
    dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
)
print(ep)

# %%
# We see clearly that the export assumed that ``x`` and ``y`` had the same
# shape. No other configuration seemed to work at export time; even with
# ``torch.fx.experimental._config.patch(backed_size_oblivious=True)``,
# the export fails when the shape of one tensor is equal to ``(1,)``.

output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])

# %%
# The final shape is:

shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
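
# %%
# A short illustrative aside: both placeholders should carry the same
# symbolic dimension, which is how the exporter records its equality
# assumption.

for n in ep.graph.nodes:
    if n.op == "placeholder":
        print("placeholder", n.name, "has shape", n.meta["val"].shape)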

# %%
# Tracing
# +++++++
#
# Let's compare with what a simple tracing would do, using :class:`torch.fx.Tracer`.

graph = Tracer().trace(Model())
print(graph)

# %%
output = [n for n in graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
print("The tracer leaves no trace:", output.args[0].__dict__)

# %%
# Shape propagation
# +++++++++++++++++

gm = torch.fx.GraphModule(Model(), graph)

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
# d1 = shape_env.create_unbacked_symint()
# d2 = shape_env.create_unbacked_symint()
fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
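
# %%
# After propagation, the traced graph's nodes should now carry shape
# metadata in ``meta["val"]`` (an illustrative sketch, contrasting with the
# empty ``__dict__`` printed in the tracing section):

for n in gm.graph.nodes:
    print(n.op, n.name, n.meta.get("val"))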

# %%
# Handle Different Shapes
# +++++++++++++++++++++++

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)

# %%
# Conclusion
# ++++++++++
#
# We need to give the inputs distinct dimensions to get distinct symbolic names.

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)

# %%
try:
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
    print(e)

# %%
# By applying the patches:

with torch_export_patches():
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)

# %%
# This is what we want. Let's go back to :func:`torch.export.export`.

with torch_export_patches():
    ep = torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
print(ep)

# %%
output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import (
    ExtTestCase,
    hide_stdout,
    requires_transformers,
    requires_torch,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasksTextGeneration(ExtTestCase):
    @hide_stdout()
    @requires_transformers("4.53")
    @requires_torch("2.7.99")
    def test_image_text_to_text_gemma3_for_causallm(self):
        mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text-generation")
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model(**torch_deepcopy(inputs))
        model(**data["inputs2"])
        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
            torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
            )


if __name__ == "__main__":
    unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 19 additions & 0 deletions
@@ -491,6 +491,25 @@ def forward(self, x, ind1, ind2):
         )
         self.assertEqualArray(expected, ep.module()(*inputs))
 
+    def test_broadcast_max(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                return x * y
+
+        Dim = torch.export.Dim
+        with torch_export_patches():
+            ep = torch.export.export(
+                Model(),
+                (
+                    torch.tensor([2, 3], dtype=torch.float32),
+                    torch.tensor([2, 3, 4], dtype=torch.float32),
+                ),
+                dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
+            )
+        output = [n for n in ep.graph.nodes if n.op == "output"]
+        shape = output[0].args[0][0].meta["val"].shape
+        self.assertEqual(str(shape), "torch.Size([Max(s17, s77)])")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
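
The symbol names ``s17`` and ``s77`` in the last assertion are assigned by the exporter and are not guaranteed to stay stable across torch versions. A more tolerant variant of that assertion (a sketch, reusing the same ``shape`` variable) could match the pattern instead:

    self.assertRegex(str(shape), r"torch\.Size\(\[Max\(s\d+, s\d+\)\]\)")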

onnx_diagnostic/torch_models/code_sample.py

Lines changed: 0 additions & 1 deletion
@@ -132,7 +132,6 @@ def make_export_code(
     if opset:
         args.append(f"opset_version={opset}")
     sargs = ", ".join(args)
-    imports = []
     code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})"])
     if optimization:
         imports.append("import onnxscript")

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 38 additions & 0 deletions
@@ -4865,3 +4865,41 @@ def _ccached_google_gemma_3_4b_it_like():
         },
     }
 )
+
+
+def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
+    "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+    return transformers.Gemma3TextConfig(
+        **{
+            "architectures": ["Gemma3ForCausalLM"],
+            "attention_bias": false,
+            "attention_dropout": 0.0,
+            "attn_logit_softcapping": null,
+            "bos_token_id": 2,
+            "cache_implementation": "hybrid",
+            "eos_token_id": [1, 106],
+            "final_logit_softcapping": null,
+            "head_dim": 8,
+            "hidden_activation": "gelu_pytorch_tanh",
+            "hidden_size": 16,
+            "initializer_range": 0.02,
+            "intermediate_size": 32,
+            "max_position_embeddings": 32768,
+            "model_type": "gemma3_text",
+            "num_attention_heads": 2,
+            "num_hidden_layers": 2,
+            "num_key_value_heads": 1,
+            "pad_token_id": 0,
+            "query_pre_attn_scalar": 256,
+            "rms_norm_eps": 1e-06,
+            "rope_local_base_freq": 10000,
+            "rope_scaling": null,
+            "rope_theta": 1000000,
+            "sliding_window": 512,
+            "sliding_window_pattern": 6,
+            "torch_dtype": "float32",
+            "transformers_version": "4.52.0.dev0",
+            "use_cache": true,
+            "vocab_size": 262144,
+        }
+    )

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ select = [
 "_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_scripts/compare_model_execution.py" = ["E402", "F401"]
-"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
+"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "RUF015", "SIM105", "SIM117"]
 "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
 "onnx_diagnostic/export/__init__.py" = ["F401"]
 "onnx_diagnostic/helpers/__init__.py" = ["F401"]
