
Commit 5a4c01c

fix patch for other version of transformers
1 parent f06d479 commit 5a4c01c

File tree: 4 files changed, +45 -8 lines changed

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def test_text_generation_phi_3_mini_128k_instruct(self):

     @hide_stdout()
     @requires_transformers("4.53")
-    @requires_torch("2.7.99")
+    @requires_torch("2.8.99")  # check_guards not supported
     def test_text_generation_tiny_llm(self):
         mid = "arnir0/Tiny-LLM"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
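For context, `requires_torch("2.8.99")` skips the test on older torch builds. Below is a minimal sketch of such a version gate; it is illustrative only and not the actual `onnx_diagnostic.ext_test_case` implementation (the decorator name `requires_torch_at_least` is made up here):

    import unittest

    import torch
    from packaging.version import Version


    def requires_torch_at_least(version: str):
        # Skip the decorated test when the installed torch is older than `version`.
        installed = Version(torch.__version__.split("+")[0])
        return unittest.skipUnless(
            installed >= Version(version),
            f"torch>={version} required, found {torch.__version__}",
        )


    class ExampleVersionGate(unittest.TestCase):
        @requires_torch_at_least("2.8.99")  # mirrors the bump in this commit
        def test_needs_recent_torch(self):
            self.assertTrue(torch.is_tensor(torch.ones(1)))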

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 import transformers
 import transformers.integrations.sdpa_attention as sdpa_attention
 import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
-from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, ignore_warnings
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

@@ -123,7 +123,7 @@ def test_causal_mask_in_scaled_dot_product_attention(self):
         attn_causal_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         self.assertEqual(attn_causal_bias.min().item(), -float("inf"))

-    # @ignore_warnings(UserWarning)
+    @ignore_warnings(UserWarning)
     def test_causal_mask_in_scaled_dot_product_attention_export(self):
         sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
         patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
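These tests compare an explicit causal bias against SDPA's built-in causal mode. A self-contained sketch of that equivalence (shapes chosen arbitrarily here, not taken from the test file):

    import torch
    import torch.nn.functional as F

    # (batch, heads, seq_len, head_dim)
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)

    # Build an additive causal bias: upper triangle filled with -inf.
    temp_mask = torch.ones(4, 4, dtype=torch.bool).tril(diagonal=0)
    attn_causal_bias = torch.zeros(4, 4)
    attn_causal_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    assert attn_causal_bias.min().item() == -float("inf")

    # The explicit bias and is_causal=True produce the same attention output.
    out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_causal_bias)
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out_mask, out_causal)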

onnx_diagnostic/export/api.py

Lines changed: 32 additions & 1 deletion

@@ -15,7 +15,38 @@ def to_onnx(
     output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     exporter: str = "onnx-dynamo",
 ) -> Any:
-    """Common API for exporters."""
+    """
+    Common API for exporters. By default, the models are optimized to use the
+    most efficient kernels implemented in :epkg:`onnxruntime`.
+
+    :param mod: torch model
+    :param args: unnamed arguments
+    :param kwargs: named arguments
+    :param input_names: input names for the onnx model (optional)
+    :param target_opset: opset to target; if not specified, each converter
+        keeps its default value
+    :param verbose: verbosity level
+    :param dynamic_shapes: dynamic shapes, usually a nested structure
+        including a dictionary for each tensor
+    :param filename: output filename
+    :param output_names: to change the output names of the onnx model
+    :param output_dynamic_shapes: to overwrite the dynamic shape names
+    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
+    :return: the output of the selected exporter, usually a structure including
+        an onnx model
+
+    A simple example:
+
+    .. code-block:: python
+
+        to_onnx(
+            model,
+            kwargs=inputs,
+            dynamic_shapes=ds,
+            exporter=exporter,
+            filename=filename,
+        )
+    """
     if exporter == "custom":
         from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
         from experimental_experiment.xbuilder import OptimizationOptions
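A hedged usage sketch of the documented API follows; the model, inputs, and dynamic-shape spec are invented for illustration, and only the `to_onnx` keyword names come from the docstring above:

    import torch
    from onnx_diagnostic.export.api import to_onnx


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x) + 1


    # Hypothetical inputs and dynamic shapes; dimension names are given as
    # strings here, which recent torch.export versions accept.
    inputs = {"x": torch.randn(2, 8)}
    ds = {"x": {0: "batch"}}

    to_onnx(
        TinyModel(),
        kwargs=inputs,
        dynamic_shapes=ds,
        exporter="onnx-dynamo",
        filename="tiny_model.onnx",
    )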

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 10 additions & 4 deletions

@@ -66,6 +66,7 @@ def _is_torchdynamo_exporting() -> bool:
         return False


+patch_is_causal = _has_transformers("4.55")
 patch_is_initialized = _has_transformers("4.56.99")


@@ -1365,10 +1366,15 @@ def patched_sdpa_attention_forward(
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]

-    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
-    # PATCHED: remove the test query.shape[2] > 1
-    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
-    is_causal = attention_mask is None and is_causal
+    if patch_is_causal:
+        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+
+        # PATCHED: remove the test query.shape[2] > 1
+        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+        # and we split the test to keep the minimum in torch.cond
+        is_causal = attention_mask is None and is_causal
+    elif is_causal is None:
+        is_causal = attention_mask is None

     torch._check(
         attention_mask is None or attention_mask.shape[3] == key.shape[2],
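To make the new branch easier to read outside the diff, here is a standalone restatement of the patched decision; `decide_is_causal` and the constant stand-in for `_has_transformers("4.55")` are introduced only for illustration:

    from typing import Optional

    import torch

    patch_is_causal = True  # stands in for _has_transformers("4.55")


    def decide_is_causal(
        module: torch.nn.Module,
        attention_mask: Optional[torch.Tensor],
        is_causal: Optional[bool],
    ) -> bool:
        if patch_is_causal:
            if is_causal is None:
                is_causal = getattr(module, "is_causal", True)
            # Patched behaviour: no data-dependent test on query.shape[2] > 1,
            # keeping the condition traced by torch.cond to the mask check only.
            return attention_mask is None and is_causal
        if is_causal is None:
            return attention_mask is None
        return is_causal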
