Skip to content

Commit 6fea147

Browse files
committed
fix torch.export 0/1 specializing
1 parent f15360e commit 6fea147

File tree

4 files changed

+40
-25
lines changed

4 files changed

+40
-25
lines changed

_unittests/ut_tasks/test_tasks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def test_text_generation(self):
4343
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
4444
model(**inputs)
4545
model(**data["inputs2"])
46-
with torch_export_patches(patch_transformers=True, verbose=10):
46+
with torch_export_patches(
47+
patch_transformers=True, verbose=10
48+
), torch.fx.experimental._config.patch(backed_size_oblivious=True):
4749
torch.export.export(
4850
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
4951
)

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,9 @@ def test_phi2_export_module(self):
307307
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
308308
)
309309

310-
with torch_export_patches(patch_transformers=True):
310+
with torch_export_patches(
311+
patch_transformers=True
312+
), torch.fx.experimental._config.patch(backed_size_oblivious=True):
311313
ep = torch.export.export(
312314
model,
313315
(),
@@ -346,7 +348,9 @@ def test_phi2_export_interpreter(self):
346348
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
347349
)
348350

349-
with torch_export_patches(patch_transformers=True, verbose=1):
351+
with torch_export_patches(
352+
patch_transformers=True, verbose=1
353+
), torch.fx.experimental._config.patch(backed_size_oblivious=True):
350354
if masking_utils is not None:
351355
self.assertEqual(
352356
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],

onnx_diagnostic/tasks/text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def get_inputs(
221221
"input_ids": {0: batch, 1: seq_length},
222222
"attention_mask": {
223223
0: batch,
224-
1: "cache+seq", # past_seq_length + seq_length
224+
1: "past_seq_length+seq_length", # past_seq_length + seq_length
225225
},
226226
"position_ids": {
227227
0: batch,

onnx_diagnostic/torch_models/validate.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,16 +1105,24 @@ def call_torch_export_export(
11051105
print("[call_torch_export_export] export...")
11061106

11071107
model = data["model"]
1108+
1109+
def _run_torch_export():
1110+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
1111+
ep = torch.export.export(
1112+
model,
1113+
args,
1114+
kwargs=kwargs,
1115+
dynamic_shapes=dse,
1116+
strict=strict,
1117+
)
1118+
return ep
1119+
11081120
ep = _quiet_or_not_quiet(
11091121
quiet,
11101122
"export_export",
11111123
summary,
11121124
data,
1113-
(
1114-
lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: (
1115-
torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s)
1116-
)
1117-
),
1125+
_run_torch_export,
11181126
)
11191127
if "ERR_export_export" in summary:
11201128
return summary, data
@@ -1715,23 +1723,24 @@ def call_torch_export_custom(
17151723
kws["target_opset"] = opset
17161724
if output_names:
17171725
kws["output_names"] = output_names
1718-
1719-
epo, opt_stats = _quiet_or_not_quiet(
1720-
quiet,
1721-
"export_export_onnx_c",
1722-
summary,
1723-
data,
1724-
(
1725-
lambda m=model, args=args, kwargs=kwargs, kws=kws: (
1726-
to_onnx(
1727-
model,
1728-
args,
1729-
kwargs=kwargs,
1730-
**kws,
1726+
# anti-specializing 0/1 during torch.export.export
1727+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
1728+
epo, opt_stats = _quiet_or_not_quiet(
1729+
quiet,
1730+
"export_export_onnx_c",
1731+
summary,
1732+
data,
1733+
(
1734+
lambda m=model, args=args, kwargs=kwargs, kws=kws: (
1735+
to_onnx(
1736+
model,
1737+
args,
1738+
kwargs=kwargs,
1739+
**kws,
1740+
)
17311741
)
1732-
)
1733-
),
1734-
)
1742+
),
1743+
)
17351744
if "ERR_export_onnx_c" in summary:
17361745
return summary, data
17371746

0 commit comments

Comments
 (0)