---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
-+++ flash.py 2023-11-28 16:14:25.206128903 +0000
-@@ -31,39 +31,39 @@
+--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
++++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
+@@ -36,44 +36,44 @@
 
  FLASH_VERSION = "0.0.0"
  try:
...
 -        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 -
 -        FLASH_VERSION = flash_attn.__version__
--        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
--        if flash_ver_parsed < (2, 3):
--            raise ImportError("Requires 2.3 for sliding window support")
+-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+-        if (
+-            flash_ver_parsed != (2, 3, 6)
+-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+-        ):
+-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 +    #try:
 +    #    from ... import _C_flashattention  # type: ignore[attr-defined]
 +    #    from ..._cpp_lib import _build_metadata
...
 +    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 +
 +    FLASH_VERSION = flash_attn.__version__
-+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-+    #    if flash_ver_parsed < (2, 3):
-+    #        raise ImportError("Requires 2.3 for sliding window support")
++    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
++    #    if (
++    #        flash_ver_parsed != (2, 3, 6)
++    #        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
++    #    ):
++    #        raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 
      # create library so that flash-attn goes through the PyTorch Dispatcher
 -    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
-
+-
 -    _flash_lib.define(
 -        "flash_fwd(Tensor query, Tensor key, Tensor value, "
--        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 -        "int max_seqlen_q, int max_seqlen_k, "
 -        "float p, float softmax_scale, "
--        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+-        "bool is_causal, int window_left, "
+-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 -    )
--
++    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
+ 
 -    _flash_lib.define(
 -        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
 -        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 -        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 -        "int max_seqlen_q, int max_seqlen_k, "
--        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+-        "float p, float softmax_scale, bool is_causal, "
+-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 -    )
 +    #_flash_lib.define(
 +    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
-+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
++    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 +    #    "int max_seqlen_q, int max_seqlen_k, "
 +    #    "float p, float softmax_scale, "
-+    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
++    #    "bool is_causal, int window_left, "
++    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 +    #)
 +
 +    #_flash_lib.define(
 +    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
 +    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 +    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 +    #    "int max_seqlen_q, int max_seqlen_k, "
-+    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
++    #    "float p, float softmax_scale, bool is_causal, "
++    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 +    #)
 
      def _flash_fwd(
         query,
-@@ -98,8 +98,8 @@
+@@ -111,8 +111,8 @@
             p,
             softmax_scale,
             is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,  # window_size_left
+-            window_right,  # window_size_right
++            # window_left,  # window_size_left
++            # window_right,  # window_size_right
             return_softmax,
             None,  # rng
         )
-@@ -127,8 +127,8 @@
+@@ -134,15 +134,15 @@
+             out,
+             cu_seq_lens_q,
+             cu_seq_lens_k,
+-            seqused_k,
++            # seqused_k,
+             max_seq_len_q,
+             max_seq_len_k,
+             p,
             softmax_scale,
             False,
             is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,
+-            window_right,
++            # window_left,
++            # window_right,
             return_softmax,
             None,
         )
-@@ -169,8 +169,8 @@
+@@ -184,8 +184,8 @@
             p,
             softmax_scale,
             is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,
+-            window_right,
++            # window_left,
++            # window_right,
             None,
             rng_state,
         )
-@@ -193,15 +193,15 @@
+@@ -208,15 +208,15 @@
             softmax_scale,
             False,  # zero_tensors
             is_causal,
--            window_size - 1,  # window_size_left
--            -1,  # window_size_right
-+            # window_size - 1,  # window_size_left
-+            # -1,  # window_size_right
+-            window_left,
+-            window_right,
++            # window_left,
++            # window_right,
             None,
             rng_state,
         )
...
  except ImportError:
      pass
 
-@@ -348,7 +348,7 @@
+@@ -400,7 +400,7 @@
          implementation.
      """
 
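Note on the updated patch: it targets flash-attn 2.3.6 (whose varlen_fwd API takes seqused_k and separate window_left/window_right arguments), then comments out both the strict version check and the torch.library dispatcher registration, importing flash_attn_cuda directly instead. Below is a minimal smoke-test sketch for a patched install, assuming flash-attn and the patched xformers are present in the environment; the version gate mirrors the check the patch disables, and none of this code is part of the patch itself.

import os

import flash_attn
# Same C extension the patch imports in place of xformers' bundled one.
from flash_attn.flash_attn_interface import flash_attn_cuda  # noqa: F401

# Mirror of the (commented-out) version gate from the patch.
ver = tuple(int(s) for s in flash_attn.__version__.split(".")[:3])
if ver != (2, 3, 6) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
    print(f"warning: flash-attn {flash_attn.__version__} != 2.3.6; varlen_fwd signature may differ")

# The patched module should import without registering the
# "xformers_flash" dispatcher library.
import xformers.ops.fmha.flash as patched_flash

print("FLASH_VERSION reported by xformers:", patched_flash.FLASH_VERSION)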