
Commit cb27ace ("fix issues")
Parent: 61b19f7

File tree: 3 files changed, 42 insertions(+), 1 deletion(-)

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     has_transformers,
     requires_transformers,
 )
-from onnx_diagnostic.helpers.torch_helper import to_any
+from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
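
The test module now also imports torch_deepcopy. As a hedged illustration (not part of the commit, and assuming the helper deep-copies nested containers of tensors), the point of such a helper is to keep the original inputs untouched when a copy is mutated or consumed:

import torch
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

# Hypothetical usage sketch: copy a dict of tensors, mutate the copy,
# and check that the original inputs stay unchanged (assumed behavior).
inputs = {"input_ids": torch.ones((1, 4), dtype=torch.int64)}
copy = torch_deepcopy(inputs)
copy["input_ids"] += 1
print(inputs["input_ids"][0, 0].item())  # 1: the original tensor is untouched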

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 36 additions & 0 deletions
@@ -864,6 +864,42 @@ def wrapper(self, x, position_ids):
     return wrapper
 
 
+def patched_model_bart_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        # The two following lines were added.
+        if attention_mask is not None and attention_mask.ndim == 4:
+            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+    attn_weights = torch.nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
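
The patched BART eager attention slices a 4-D attention mask down to the key length before adding it to the attention scores, so a mask padded to a longer cache no longer breaks the addition. A minimal standalone sketch of that slicing step, with illustrative shapes that are not taken from the commit:

import torch

# Illustrative shapes only: the mask covers 8 positions but the keys only
# cover 5, so the mask is cut to key.shape[-2] before being added.
batch, heads, q_len, kv_len, mask_len = 2, 4, 3, 5, 8
attn_weights = torch.zeros(batch, heads, q_len, kv_len)
attention_mask = torch.zeros(batch, 1, q_len, mask_len)

attention_mask = attention_mask[:, :, :, :kv_len]  # the added truncation
attn_weights = attn_weights + attention_mask       # broadcasts to (2, 4, 3, 5)
print(attn_weights.shape)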

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 5 additions & 0 deletions
@@ -144,6 +144,11 @@ def get_untrained_model_with_inputs(
             f"[get_untrained_model_with_inputs] config._attn_implementation="
             f"{config._attn_implementation!r}"  # type: ignore[union-attr]
         )
+    elif verbose:
+        print(
+            f"[get_untrained_model_with_inputs] default config._attn_implementation="
+            f"{config._attn_implementation!r}"  # type: ignore[union-attr]
+        )
 
     if type(config) is dict and "_diffusers_version" in config:
         import diffusers
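
With the new elif branch, get_untrained_model_with_inputs also reports the default _attn_implementation when it was not explicitly overridden. A hedged usage sketch; the model id and the printed value are assumptions, not taken from the commit:

from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# Hypothetical call: any Hugging Face model id works; verbose=1 enables the logging.
data = get_untrained_model_with_inputs("hf-internal-testing/tiny-random-BartModel", verbose=1)
# Expected kind of log line (assumed value):
# [get_untrained_model_with_inputs] default config._attn_implementation='eager'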
