Skip to content

Commit 78a024b

Browse files
authored
fix patches (#271)
* fix patches
* fix patch
* fixes shape information
* fix issues
1 parent ca25dc6 commit 78a024b

File tree

6 files changed

+87
-59
lines changed

6 files changed

+87
-59
lines changed

_doc/technical/plot_broadcast_export_issue.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -80,8 +80,8 @@ def forward(self, x, y):
8080
# d1 = shape_env.create_unbacked_symint()
8181
# d2 = shape_env.create_unbacked_symint()
8282
fake_inputs = fake_mode.from_tensor(
83-
torch.zeros((2,), dtype=torch.float32), static_shapes=False
84-
), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False)
83+
torch.zeros((3,), dtype=torch.float32), static_shapes=False
84+
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
8585

8686
print("fake_inputs are ", fake_inputs)
8787
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
@@ -115,7 +115,7 @@ def forward(self, x, y):
115115
try:
116116
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
117117
except Exception as e:
118-
print(e)
118+
print("error", e)
119119

120120
# %%
121121
# By applying the patches:

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ class TestTasksTextGeneration(ExtTestCase):
1616
@hide_stdout()
1717
@requires_transformers("4.53")
1818
@requires_torch("2.7.99")
19-
def test_image_text_to_text_gemma3_for_causallm(self):
19+
def test_tet_generation_gemma3_for_causallm(self):
2020
mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
2121
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
2222
self.assertEqual(data["task"], "text-generation")
@@ -28,6 +28,23 @@ def test_image_text_to_text_gemma3_for_causallm(self):
2828
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
2929
)
3030

31+
@hide_stdout()
32+
@requires_transformers("4.53")
33+
@requires_torch("2.7.99")
34+
def test_itext_generation_phi_3_mini_128k_instruct(self):
35+
mid = "microsoft/Phi-3-mini-128k-instruct"
36+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
37+
self.assertEqual(data["task"], "text-generation")
38+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
39+
print("--", self.string_type(inputs, with_shape=True))
40+
print("--", self.string_type(ds))
41+
model(**torch_deepcopy(inputs))
42+
model(**data["inputs2"])
43+
with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
44+
torch.export.export(
45+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
46+
)
47+
3148

3249
if __name__ == "__main__":
3350
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 60 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -371,30 +371,34 @@ def __call__(self, parser, namespace, values, option_string=None):
371371
setattr(namespace, self.dest, d)
372372

373373

374-
def get_parser_validate() -> ArgumentParser:
374+
def get_parser_validate(name: str = "validate") -> ArgumentParser:
375375
parser = ArgumentParser(
376-
prog="validate",
376+
prog=name,
377377
description=textwrap.dedent(
378378
"""
379-
Prints out dummy inputs for a particular task or a model id.
380-
If both mid and task are empty, the command line displays the list
381-
of supported tasks.
379+
Validates a model for a particular task given the model id.
380+
It exports the model and then validates it by computing the discrepancies
381+
on different input sets.
382+
"""
383+
if name == "validate"
384+
else """
385+
Creates a script to export a model for a particular task given the model id.
382386
"""
383387
),
384388
epilog=textwrap.dedent(
385-
"""
389+
f"""
386390
If the model id is specified, one untrained version of it is instantiated.
387391
Examples:
388392
389-
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
393+
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
390394
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
391395
--dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
392396
393-
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
397+
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
394398
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
395399
--dtype float16 --device cuda --patch --export custom --opt default
396400
397-
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
401+
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
398402
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
399403
--dtype float16 --device cuda --export modelbuilder
400404
@@ -405,12 +409,12 @@ def get_parser_validate() -> ArgumentParser:
405409
The behaviour may be modified compare the original configuration,
406410
the following argument can be rope_scaling to dynamic:
407411
408-
--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
412+
--mop \"rope_scaling={{'rope_type': 'dynamic', 'factor': 10.0}}\""
409413
410414
You can profile the command line by running:
411415
412-
pyinstrument -m onnx_diagnostic validate ...
413-
pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
416+
pyinstrument -m onnx_diagnostic {name} ...
417+
pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ...
414418
"""
415419
),
416420
formatter_class=RawTextHelpFormatter,
@@ -460,19 +464,19 @@ def get_parser_validate() -> ArgumentParser:
460464
"--same-as-trained",
461465
default=False,
462466
action=BooleanOptionalAction,
463-
help="Validates a model identical to the trained model but not trained.",
467+
help="Validates or exports a model identical to the trained model but not trained.",
464468
)
465469
parser.add_argument(
466470
"--trained",
467471
default=False,
468472
action=BooleanOptionalAction,
469-
help="Validates the trained model (requires downloading).",
473+
help="Validates or exports the trained model (requires downloading).",
470474
)
471475
parser.add_argument(
472476
"--inputs2",
473477
default=1,
474478
type=int,
475-
help="Validates the model on a second set of inputs\n"
479+
help="Validates or exports the model on a second set of inputs\n"
476480
"to check the exported model supports dynamism. The values is used "
477481
"as an increment to the first set of inputs. A high value may trick "
478482
"a different behavior in the model and missed by the exporter.",
@@ -504,13 +508,14 @@ def get_parser_validate() -> ArgumentParser:
504508
"--subfolder",
505509
help="Subfolder where to find the model and the configuration.",
506510
)
507-
parser.add_argument(
508-
"--ortfusiontype",
509-
required=False,
510-
help="Applies onnxruntime fusion, this parameter should contain the\n"
511-
"model type or multiple values separated by `|`. `ALL` can be used\n"
512-
"to run them all.",
513-
)
511+
if name == "validate":
512+
parser.add_argument(
513+
"--ortfusiontype",
514+
required=False,
515+
help="Applies onnxruntime fusion, this parameter should contain the\n"
516+
"model type or multiple values separated by `|`. `ALL` can be used\n"
517+
"to run them all.",
518+
)
514519
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
515520
parser.add_argument("--dtype", help="Changes dtype if necessary.")
516521
parser.add_argument("--device", help="Changes the device if necessary.")
@@ -532,33 +537,38 @@ def get_parser_validate() -> ArgumentParser:
532537
"--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
533538
action=_ParseDict,
534539
)
535-
parser.add_argument(
536-
"--repeat",
537-
default=1,
538-
type=int,
539-
help="number of times to run the model to measures inference time",
540-
)
541-
parser.add_argument(
542-
"--warmup", default=0, type=int, help="number of times to run the model to do warmup"
543-
)
540+
if name == "validate":
541+
parser.add_argument(
542+
"--repeat",
543+
default=1,
544+
type=int,
545+
help="number of times to run the model to measures inference time",
546+
)
547+
parser.add_argument(
548+
"--warmup",
549+
default=0,
550+
type=int,
551+
help="number of times to run the model to do warmup",
552+
)
544553
parser.add_argument(
545554
"--outnames",
546555
help="This comma separated list defines the output names "
547556
"the onnx exporter should use.",
548557
default="",
549558
)
550-
parser.add_argument(
551-
"--ort-logs",
552-
default=False,
553-
action=BooleanOptionalAction,
554-
help="Enables onnxruntime logging when the session is created",
555-
)
556-
parser.add_argument(
557-
"--quiet-input-sets",
558-
default="",
559-
help="Avoids raising an exception when an input sets does not work with "
560-
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
561-
)
559+
if name == "validate":
560+
parser.add_argument(
561+
"--ort-logs",
562+
default=False,
563+
action=BooleanOptionalAction,
564+
help="Enables onnxruntime logging when the session is created",
565+
)
566+
parser.add_argument(
567+
"--quiet-input-sets",
568+
default="",
569+
help="Avoids raising an exception when an input sets does not work with "
570+
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
571+
)
562572
return parser
563573

564574

@@ -637,7 +647,7 @@ def _cmd_export_sample(argv: List[Any]):
637647
from .torch_models.code_sample import code_sample
638648
from .tasks import supported_tasks
639649

640-
parser = get_parser_validate()
650+
parser = get_parser_validate("exportsample")
641651
args = parser.parse_args(argv[1:])
642652
if not args.task and not args.mid:
643653
print("-- list of supported tasks:")
@@ -693,16 +703,16 @@ def _cmd_export_sample(argv: List[Any]):
693703
os.makedirs(args.dump_folder, exist_ok=True)
694704
name = (
695705
_make_folder_name(
696-
model_id=args.model_id,
697-
exporter=args.exporter,
698-
optimization=args.optimization,
706+
model_id=args.mid,
707+
exporter=args.export,
708+
optimization=args.opt,
699709
dtype=args.dtype,
700710
device=args.device,
701711
subfolder=args.subfolder,
702712
opset=args.opset,
703713
drop_inputs=None if not args.drop else args.drop.split(","),
704-
same_as_pretrained=args.same_as_pretrained,
705-
use_pretrained=args.use_pretrained,
714+
same_as_pretrained=args.same_as_trained,
715+
use_pretrained=args.trained,
706716
task=args.task,
707717
).replace("/", "-")
708718
+ ".py"
@@ -1111,7 +1121,7 @@ def main(argv: Optional[List[Any]] = None):
11111121
validate=get_parser_validate,
11121122
stats=get_parser_stats,
11131123
agg=get_parser_agg,
1114-
exportsample=get_parser_validate,
1124+
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
11151125
)
11161126
cmd = argv[0]
11171127
if cmd not in parsers:

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def get_inputs_default(
271271
"input_ids": {0: batch, 1: seq_length},
272272
"token_type_ids": {0: batch, 1: seq_length},
273273
"attention_mask": {0: batch, 1: "cache+seq"},
274-
"position_ids": {0: batch, 1: "cache+seq"},
274+
"position_ids": {0: batch, 1: seq_length},
275275
"past_key_values": [
276276
[{0: batch} for _ in range(num_hidden_layers)],
277277
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

onnx_diagnostic/tasks/text_generation.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -220,10 +220,7 @@ def get_inputs(
220220
0: batch,
221221
1: "cache+seq", # cache_length + seq_length
222222
},
223-
"position_ids": {
224-
0: batch,
225-
1: "cache+seq", # cache_length + seq_length
226-
},
223+
"position_ids": {0: batch, 1: seq_length},
227224
"past_key_values": [
228225
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
229226
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1312,6 +1312,10 @@ def patched_sdpa_attention_forward(
13121312
# is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
13131313
is_causal = attention_mask is None and is_causal
13141314

1315+
torch._check(
1316+
attention_mask is None or attention_mask.shape[3] == key.shape[2],
1317+
"Attention mask shape incompatible with key shape.",
1318+
)
13151319
attn_output = torch.nn.functional.scaled_dot_product_attention(
13161320
query,
13171321
key,

0 commit comments

Comments
 (0)