
Commit 0478a59

Commit message: payvhrd'
1 parent 4afcc1b commit 0478a59

5 files changed: +223 additions, −17 deletions


.github/workflows/ci.yml

Lines changed: 8 additions & 1 deletion
@@ -17,6 +17,7 @@ jobs:
         os: [ubuntu-latest]
         python: ['3.11', '3.12']
         transformers: ['4.48', '4.50', 'main']
+        torch: ['2.6', 'main']

     steps:
       - uses: actions/checkout@v3
@@ -26,7 +27,13 @@ jobs:
           python-version: ${{ matrix.python }}

       - name: Install pytorch
-        run: python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+        run: |
+          if [[ "${{ matrix.torch }}" == "main" ]]; then
+            python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+          else
+            echo "install torch==${{ matrix.torch }}"
+            pip install torch==${{ matrix.torch }}
+          fi

       - name: Install transformers ${{ matrix.transformers }}
         run: |

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 4 additions & 8 deletions
@@ -17,27 +17,23 @@ def test_get_phi2(self):
     @requires_transformers("4.52")
     def test_export_phi2_1(self):
         data = get_phi2(num_hidden_layers=2)
-        model, inputs = data["model"], data["inputs"]
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        ep = torch.export.export(
-            model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
-        )
+        ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
         assert ep

     @ignore_warnings(UserWarning)
     def test_export_phi2_2_bypassed(self):
         data = get_phi2(num_hidden_layers=2)
-        model, inputs = data["model"], data["inputs"]
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
         with bypass_export_some_errors(patch_transformers=True) as modificator:
             inputs = modificator(inputs)
-            ep = torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
-            )
+            ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
         assert ep
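The refactoring above reduces each export call to one line by unpacking the dynamic shapes once. Outside the test suite, the same pattern looks roughly like this minimal sketch; the import path for get_phi2 is assumed by analogy with llm_tiny_llm.py below and is not shown in this diff:

    import torch

    # Assumed import path, mirroring onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py.
    from onnx_diagnostic.torch_models.untrained.llm_phi2 import get_phi2

    data = get_phi2(num_hidden_layers=2)  # a small configuration keeps the export fast
    # Unpack the dynamic shapes once, as the updated tests do.
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
    print(ep)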

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 29 additions & 6 deletions
@@ -1,3 +1,4 @@
+import copy
 import inspect
 import unittest
 import torch
@@ -57,19 +58,41 @@ def test_onnx_export_tiny_llm_xdbg(self):

     @ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
     @hide_stdout()
-    def test_bypass_onnx_export_tiny_llm_official(self):
+    def test_bypass_onnx_export_tiny_llm_official_nopositionids(self):
         data = get_tiny_llm()
-        model, inputs = data["model"], data["inputs"]
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        del inputs["position_ids"]
+        del ds["position_ids"]
+        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
+        with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
+            new_inputs = modificator(copy.deepcopy(inputs))
+            ep = torch.onnx.export(
+                model,
+                (),
+                kwargs=new_inputs,
+                dynamic_shapes=ds,
+                dynamo=True,
+                optimize=True,
+            )
+        self.assert_onnx_disc(
+            inspect.currentframe().f_code.co_name, ep.model_proto, model, inputs, verbose=1
+        )
+
+    @ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
+    @hide_stdout()
+    def test_bypass_onnx_export_tiny_llm_official_full(self):
+        data = get_tiny_llm()
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
         with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
-            new_inputs = modificator(inputs)
+            new_inputs = modificator(copy.deepcopy(inputs))
             ep = torch.onnx.export(
                 model,
                 (),
                 kwargs=new_inputs,
-                dynamic_shapes=data["dynamic_shapes"],
+                dynamic_shapes=ds,
                 dynamo=True,
                 optimize=True,
             )
@@ -82,7 +105,7 @@ def test_bypass_onnx_export_tiny_llm_official(self):
     @hide_stdout()
     def test_bypass_onnx_export_tiny_llm_xdbg(self):
         data = get_tiny_llm()
-        model, inputs = data["model"], data["inputs"]
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
@@ -92,7 +115,7 @@ def test_bypass_onnx_export_tiny_llm_xdbg(self):
             model,
             (),
             kwargs=new_inputs,
-            dynamic_shapes=data["dynamic_shapes"],
+            dynamic_shapes=ds,
             verbose=1,
             export_options=ExportOptions(strict=False),
         )
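The new test_bypass_onnx_export_tiny_llm_official_nopositionids variant exercises the export when position_ids is absent and the model must recompute it. A standalone sketch of that path, assuming bypass_export_some_errors is importable from onnx_diagnostic.torch_export_patches (the tests above only show the name, not the import):

    import copy
    import torch
    from onnx_diagnostic.torch_models.untrained.llm_tiny_llm import get_tiny_llm

    # Assumed import path; not visible in this diff.
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    data = get_tiny_llm()
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
    # Drop position_ids from both the inputs and the dynamic shapes so they stay in sync.
    del inputs["position_ids"]
    del ds["position_ids"]
    with bypass_export_some_errors(patch_transformers=True) as modificator:
        # Deep-copy first so the original inputs survive the export untouched.
        new_inputs = modificator(copy.deepcopy(inputs))
        ep = torch.onnx.export(
            model, (), kwargs=new_inputs, dynamic_shapes=ds, dynamo=True, optimize=True
        )

The deep copy also explains the change to the pre-existing test: presumably exporting with modificator(inputs) consumed or mutated the very dict that assert_onnx_disc later needs for the discrepancy check.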

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 180 additions & 0 deletions
@@ -146,3 +146,183 @@ def patched__broadcast_shapes(*_shapes):
             common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])

     return common_shape
+
+
+class patched_ShapeEnv:
+
+    def _set_replacement(
+        self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str  # noqa: F821
+    ) -> None:
+        """
+        Adds or updates a replacement for a symbol.
+        Use this instead of `self.replacements[a] = tgt`.
+        """
+        if tgt == self.replacements.get(a, None):
+            return
+
+        if a in tgt.free_symbols:
+            return
+
+        import sympy
+        from torch._logging import structured
+        from torch.utils._traceback import CapturedTraceback
+        from torch._logging import trace_structured
+        from torch._guards import TracingContext
+        from torch.utils._sympy.functions import FloorToInt, CeilToInt
+        from torch.utils._sympy.solve import try_solve
+        from torch.fx.experimental.symbolic_shapes import (
+            _is_supported_equivalence,
+            ValueRanges,
+        )
+
+        # Precondition: a == tgt
+        assert isinstance(a, sympy.Symbol)
+
+        if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt):
+            # continuing leads to placeholder shapes
+            # having complex expressions that we can't resolve
+            return
+
+        # Handles nested tensor symbolic variables which don't have
+        # var_to_range bounds
+        tgt_bound = None
+        if a in self.var_to_range:
+            src_bound = self.var_to_range[a]
+
+            # First, refine the value range of a based on the computed value range
+            # of tgt. This is always OK to do, even if we decide not to do the
+            # substitution in the end. This might be a no-op, if a already has
+            # a tighter bound
+            tgt_bound = self.bound_sympy(tgt)
+            self._update_var_to_range(a, tgt_bound)

+            # Next, check if we can update the range of free symbols in tgt
+            # based on the range in a. But only do it if:
+            # - the source bound non-trivially improves over what we get out of
+            #   the existing bounds.
+            # - the replacement is univariate and we can invert the tgt expression
+            if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
+                b = next(iter(tgt.free_symbols))
+                # Try to invert the equality
+                r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
+                if r is not None:
+                    self.log.debug(
+                        "set_replacement: solve for %s in %s == %s gives %s",
+                        b,
+                        a,
+                        tgt,
+                        r,
+                    )
+                    # The solution here can be non-integral, for example, if
+                    # we have s0 = 2*s1, then s1 = s0/2. What we would like
+                    # to do is calculated the bounds in arbitrary precision,
+                    # and then requantize the bound to integers when we are
+                    # done.
+                    rat_b_bound = self.bound_sympy(r[1])
+                    b_bound = ValueRanges(
+                        CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)
+                    )
+                    self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
+                    tgt_bound = self.bound_sympy(tgt)
+                    assert tgt_bound.issubset(
+                        src_bound
+                    ), f"{tgt_bound=} not a subset of {src_bound=}"
+
+            # TODO: Should we propagate size-like-ness?
+            #
+            # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
+            # to become size-like.
+            #
+            # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
+            # propagate in this case, because what if u0 == 0, then u1 is negative
+            # and clearly isn't a size. So, at minimum, any f(x) whose value
+            # range isn't [0, inf] given x in [0, inf] cannot propagate
+            # size-like-ness. But there are many situations where you could
+            # imagine u1 is going to be size-like and actually you just didn't
+            # have a refined enough value range on u0. Since even innocuous
+            # looking arithmetic operations can destroy size-like-ness, it's
+            # best to not propagate it at all and force the user to annotate it
+            # as necessary.
+            #
+            # Compromise: we preserve size-like-ness only for exact equality
+            # and nothing else.
+            if a in self.size_like and isinstance(tgt, sympy.Symbol):
+                self.size_like.add(tgt)
+            elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
+                self.size_like.add(a)
+
+            # Now, decide if we will do the substitution.
+            #
+            # - If the source has a non-trivial range, only substitute if
+            #   we preserve this range. Note that we may have propagated
+            #   the src_range to free variables in tgt when tgt is univariate
+            #   and we could find an inverse, which helps us achieve this.
+            #   This ensures we never "forget" about user defined ranges,
+            #   even if they end up being defined on composite formulas
+            #   like s0 + s1.
+            #
+            # - If the variable is unbacked, only substitute if the substitution
+            #   would preserve the bounds also under size-like-ness conditions.
+
+            if not tgt_bound.issubset(src_bound):
+                self.log.debug(
+                    "skipped set_replacement %s = %s (%s) [%s not subset of %s]",
+                    a,
+                    tgt,
+                    msg,
+                    tgt_bound,
+                    src_bound,
+                )
+                return
+            elif a in self.size_like:
+                tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
+                src_bound_so = self.bound_sympy(a, size_oblivious=True)
+                if not tgt_bound_so.issubset(src_bound_so):
+                    self.log.debug(
+                        "skipped set_replacement %s = %s (%s) "
+                        "[%s not subset of %s (size-oblivious conditions)]",
+                        a,
+                        tgt,
+                        msg,
+                        tgt_bound_so,
+                        src_bound_so,
+                    )
+                    return
+
+        if isinstance(tgt, (sympy.Integer, sympy.Float)):
+            # specializing to a constant, which is likely unexpected (unless
+            # you specified dynamic=True)
+
+            user_tb = TracingContext.extract_stack()
+            trace_structured(
+                "symbolic_shape_specialization",
+                metadata_fn=lambda: {
+                    "symbol": repr(a),
+                    "sources": [s.name() for s in self.var_to_sources.get(a, [])],
+                    "value": repr(tgt),
+                    "reason": msg,
+                    "stack": structured.from_traceback(
+                        CapturedTraceback.extract(skip=1).summary()
+                    ),
+                    "user_stack": (structured.from_traceback(user_tb) if user_tb else None),
+                },
+            )
+
+            # if config.print_specializations:
+            #     self.log.warning(
+            #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
+            #     )
+            #     self.log.debug("SPECIALIZATION", stack_info=True)
+        assert msg != "range_refined_to_singleton", f"{[a, tgt, msg, tgt_bound]}"
+        # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
+        self.replacements[a] = tgt
+        # NB: the replacement may get refined, but the user will find the
+        # FIRST one most useful (TODO: Maybe we could consider tracking all of
+        # them)
+        if a not in self.replacements_slocs:
+            self.replacements_slocs[a] = self._get_sloc()
+        self._update_version_counter()
+
+        # When specializing 'a == tgt', the equality should be also conveyed to
+        # Z3, in case an expression uses 'a'.
+        self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))
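patched_ShapeEnv mirrors torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement with one functional addition: the assert on msg != "range_refined_to_singleton", which appears intended to turn a silent specialization of a dynamic dimension to a constant into a hard failure during export. A sketch of how such a method patch is typically swapped in; the actual wiring used by this package is not shown in this diff:

    from contextlib import contextmanager

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_ShapeEnv


    @contextmanager
    def patch_set_replacement():
        # Temporarily replace the method on the real ShapeEnv, then restore it.
        original = ShapeEnv._set_replacement
        ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
        try:
            yield
        finally:
            ShapeEnv._set_replacement = original

Wrapped around a call to torch.export.export, such a patch makes the exporter fail loudly instead of quietly freezing a dimension.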

onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py

Lines changed: 2 additions & 2 deletions
@@ -64,11 +64,11 @@ def get_tiny_llm(

     shapes = {
         "input_ids": {0: batch, 1: seq_length},
-        "position_ids": {
+        "attention_mask": {
             0: batch,
             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
         },
-        "attention_mask": {
+        "position_ids": {
             0: batch,
             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
         },
