Skip to content

Commit e7b9dc1

Browse files
committed
fix patches
1 parent ca25dc6 commit e7b9dc1

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

_doc/technical/plot_broadcast_export_issue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def forward(self, x, y):
8080
# d1 = shape_env.create_unbacked_symint()
8181
# d2 = shape_env.create_unbacked_symint()
8282
fake_inputs = fake_mode.from_tensor(
83-
torch.zeros((2,), dtype=torch.float32), static_shapes=False
84-
), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False)
83+
torch.zeros((3,), dtype=torch.float32), static_shapes=False
84+
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
8585

8686
print("fake_inputs are ", fake_inputs)
8787
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
@@ -115,7 +115,7 @@ def forward(self, x, y):
115115
try:
116116
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
117117
except Exception as e:
118-
print(e)
118+
print("error", e)
119119

120120
# %%
121121
# By applying the patches:

onnx_diagnostic/_command_lines_parser.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -693,16 +693,16 @@ def _cmd_export_sample(argv: List[Any]):
693693
os.makedirs(args.dump_folder, exist_ok=True)
694694
name = (
695695
_make_folder_name(
696-
model_id=args.model_id,
697-
exporter=args.exporter,
698-
optimization=args.optimization,
696+
model_id=args.mid,
697+
exporter=args.export,
698+
optimization=args.opt,
699699
dtype=args.dtype,
700700
device=args.device,
701701
subfolder=args.subfolder,
702702
opset=args.opset,
703703
drop_inputs=None if not args.drop else args.drop.split(","),
704-
same_as_pretrained=args.same_as_pretrained,
705-
use_pretrained=args.use_pretrained,
704+
same_as_pretrained=args.same_as_trained,
705+
use_pretrained=args.trained,
706706
task=args.task,
707707
).replace("/", "-")
708708
+ ".py"

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,10 @@ def patched_sdpa_attention_forward(
13121312
# is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
13131313
is_causal = attention_mask is None and is_causal
13141314

1315+
torch._check(
1316+
attention_mask.shape[3] == key.shape[2],
1317+
"Attention mask shape incompatible with key shape.",
1318+
)
13151319
attn_output = torch.nn.functional.scaled_dot_product_attention(
13161320
query,
13171321
key,

0 commit comments

Comments (0)