@@ -1906,6 +1906,7 @@ def get_fa_args(
19061906 dk = None ,
19071907 dv = None ,
19081908 ):
1909+ """Get forward/backward arguments for flash-attn v2 and v3."""
19091910 if use_flash_attn_3 :
19101911 if forward :
19111912 if qkv_format == "thd" :
@@ -1918,66 +1919,59 @@ def get_fa_args(
19181919 max_seqlen_kv ,
19191920 * [None ] * 8 , # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
19201921 ]
1921- else :
1922- return [
1923- * [None ] * 9 , # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
1924- max_seqlen_q ,
1925- max_seqlen_kv ,
1926- * [None ] * 8 , # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1927- ]
1928- else :
1929- if qkv_format == "thd" :
1930- return [
1931- cu_seqlens_q ,
1932- cu_seqlens_kv ,
1933- None , # sequed_q
1934- None , # sequed_k
1935- max_seqlen_q ,
1936- max_seqlen_kv ,
1937- dq ,
1938- dk ,
1939- dv ,
1940- ]
1941- else :
1942- return [
1943- None , # cu_seqlens_q
1944- None , # cu_seqlens_kv
1945- None , # sequed_q
1946- None , # sequed_k
1947- max_seqlen_q ,
1948- max_seqlen_kv ,
1949- dq ,
1950- dk ,
1951- dv ,
1952- ]
1953- else :
1954- if forward :
1955- if qkv_format == "thd" :
1956- return [
1957- cu_seqlens_q ,
1958- cu_seqlens_kv ,
1959- max_seqlen_q ,
1960- max_seqlen_kv ,
1961- ]
1962- else :
1963- return []
1964- else :
1965- if qkv_format == "thd" :
1966- return [
1967- dq ,
1968- dk ,
1969- dv ,
1970- cu_seqlens_q ,
1971- cu_seqlens_kv ,
1972- max_seqlen_q ,
1973- max_seqlen_kv ,
1974- ]
1975- else :
1976- return [
1977- dq ,
1978- dk ,
1979- dv ,
1980- ]
1922+ return [
1923+ * [None ] * 9 , # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
1924+ max_seqlen_q ,
1925+ max_seqlen_kv ,
1926+ * [None ] * 8 , # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
1927+ ]
1928+ if qkv_format == "thd" :
1929+ return [
1930+ cu_seqlens_q ,
1931+ cu_seqlens_kv ,
1932+ None , # sequed_q
1933+ None , # sequed_k
1934+ max_seqlen_q ,
1935+ max_seqlen_kv ,
1936+ dq ,
1937+ dk ,
1938+ dv ,
1939+ ]
1940+ return [
1941+ None , # cu_seqlens_q
1942+ None , # cu_seqlens_kv
1943+ None , # sequed_q
1944+ None , # sequed_k
1945+ max_seqlen_q ,
1946+ max_seqlen_kv ,
1947+ dq ,
1948+ dk ,
1949+ dv ,
1950+ ]
1951+ if forward :
1952+ if qkv_format == "thd" :
1953+ return [
1954+ cu_seqlens_q ,
1955+ cu_seqlens_kv ,
1956+ max_seqlen_q ,
1957+ max_seqlen_kv ,
1958+ ]
1959+ return []
1960+ if qkv_format == "thd" :
1961+ return [
1962+ dq ,
1963+ dk ,
1964+ dv ,
1965+ cu_seqlens_q ,
1966+ cu_seqlens_kv ,
1967+ max_seqlen_q ,
1968+ max_seqlen_kv ,
1969+ ]
1970+ return [
1971+ dq ,
1972+ dk ,
1973+ dv ,
1974+ ]
19811975
19821976
19831977class AttnFuncWithCPAndKVP2P (torch .autograd .Function ):
0 commit comments