Skip to content

Commit 3dd887a

Browse files
committed
fix patch
1 parent dc11cfa commit 3dd887a

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@
8888
# Shapes may not match on the second call with the modified inputs.
8989

9090

91-
with torch_export_patches(patch_transformers=True):
91+
with torch_export_patches(patch_transformers=True), torch.fx.experimental._config.patch(
92+
backed_size_oblivious=True
93+
):
9294

9395
# Two unnecessary steps but useful in case of an error
9496
# We check the cache is registered.

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,13 +1672,9 @@ def patched_sdpa_attention_forward(
16721672
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # noqa: E501
16731673
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` # noqa: E501
16741674
if is_causal is None:
1675-
# The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag # noqa: E501
1676-
# This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns # noqa: E501
1677-
# is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # noqa: E501
1678-
# NOTE: query.shape[2] == 1 or > 1 should have the same output for causal attention
1679-
# so we simplify the condition to:
1680-
is_causal = attention_mask is None and getattr(module, "is_causal", True)
1681-
1675+
# NOTE: attention_mask is expected to never be None at this point
1676+
# https://github.com/huggingface/transformers/blob/def4a37e19601b597f170e81684c8b0b5f84db39/src/transformers/masking_utils.py#L240-L243
1677+
is_causal = False
16821678
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # noqa: E501
16831679
# We convert it to a bool for the SDPA kernel that only accepts bools.
16841680
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):

0 commit comments

Comments
 (0)