@@ -1893,6 +1893,112 @@ def __call__(
         return hidden_states


+class FluxAttnProcessor2_0_NPU:
+    """Attention processor for Flux-style self-attention projections, using the fused attention kernel from torch_npu on NPU devices."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU. To use it, please upgrade PyTorch to 2.0 and install torch NPU."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
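+            # concatenate the context (text) tokens in front of the sample (image) tokens along the sequence dim (dim=2 in BNSD layout)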
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
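+        # half-precision inputs go through the fused attention kernel from torch_npu; other dtypes fall back to PyTorch SDPA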
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""

@@ -1987,6 +2093,117 @@ def __call__(
         return hidden_states


+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor for Flux-style self-attention with fused QKV projections, using the fused attention kernel from torch_npu on NPU devices."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU. To use it, please upgrade PyTorch to 2.0 and install torch NPU."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
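+        # a single fused linear layer produces Q, K and V, which are split into three equal chunks along the last dim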
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
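+            # concatenate the context (text) tokens in front of the sample (image) tokens along the sequence dim (dim=2 in BNSD layout)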
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
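+        # half-precision inputs go through the fused attention kernel from torch_npu; other dtypes fall back to PyTorch SDPA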
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
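For context, a minimal usage sketch (not part of this diff, and assuming the file-level `math`, `torch_npu`, and `F` imports already present in `attention_processor.py`): once this change lands, the NPU processor could be swapped into a Flux transformer via the existing `set_attn_processor` API, roughly as follows. The model id and prompt are illustrative; the fused variant would additionally require the transformer's QKV projections to be fused so that `attn.to_qkv` exists.

    import torch
    from diffusers import FluxPipeline
    from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU

    # Load the Flux pipeline and move it to the NPU device registered by torch_npu.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("npu")

    # Replace every attention processor in the transformer with the NPU variant.
    pipe.transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())

    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]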