-    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # noqa: E501
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # noqa: E501
+    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` # noqa: E501
     if is_causal is None:
-        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
-        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
-        # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
+        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag # noqa: E501
+        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns # noqa: E501
+        # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # noqa: E501
         # NOTE: query.shape[2] == 1 or > 1 should have the same output for causal attention
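
For readers following the comments in this hunk, here is a minimal pure-Python sketch of what the commented-out `is_causal` expression would evaluate to if it were enabled. The helper name `resolve_is_causal` and its `query_len` parameter (standing in for `query.shape[2]`) are hypothetical and do not appear in the diff; the sketch uses no torch and only illustrates the order of the checks, with the shape check first as the comment recommends.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_is_causal(
    is_causal: Optional[bool],
    query_len: int,
    attention_mask: Optional[Any],
    module: Any,
) -> bool:
    """Hypothetical helper mirroring the commented-out expression in the diff."""
    if is_causal is None:
        # Shape check first, then the mask check, then the module-level flag.
        # The module flag defaults to True when the attribute is absent.
        is_causal = (
            query_len > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )
    # Return a plain Python bool regardless of which branch was taken.
    return bool(is_causal)


# Usage: a decoder-style module without an explicit flag, and an
# encoder-style module that opts out via its own `is_causal` attribute.
decoder_like = SimpleNamespace()
encoder_like = SimpleNamespace(is_causal=False)

prefill = resolve_is_causal(None, query_len=8, attention_mask=None, module=decoder_like)
single_token = resolve_is_causal(None, query_len=1, attention_mask=None, module=decoder_like)
encoder = resolve_is_causal(None, query_len=8, attention_mask=None, module=encoder_like)
```

An explicitly passed `is_causal` value bypasses all three checks, which matches the `if is_causal is None:` guard shown in the hunk.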