Commit 3bd6d30

[WIP][cogview4]: implement CogView4 attention processor
Add a CogView4AttnProcessor class implementing scaled dot-product attention with rotary embeddings for the CogView4 model. The processor concatenates the encoder and hidden states and applies the QKV projections and RoPE, but does not include spatial normalization.

TODO:
- Fix the incorrect QKV projection weights
- Resolve the ~25% error in the RoPE implementation compared to Megatron
1 parent 310da29 · commit 3bd6d30
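For orientation, below is a minimal, self-contained sketch of the data flow described above: concatenate the text and image tokens, project to query/key/value, apply RoPE to the image tokens only, run scaled dot-product attention, and split the joint sequence back into the two streams. The shapes, the toy_rope placeholder, and the standalone projections are illustrative stand-ins, not the diffusers Attention API.

import torch
import torch.nn.functional as F


def toy_rope(x):
    # Placeholder for the real rotary embedding; the identity keeps the sketch runnable.
    return x


def cogview4_attention_sketch(image_tokens, text_tokens, to_q, to_k, to_v, heads):
    # image_tokens: [batch, image_seq, dim], text_tokens: [batch, text_seq, dim]
    text_seq = text_tokens.size(1)
    x = torch.cat([text_tokens, image_tokens], dim=1)  # joint sequence, text first
    b, s, _ = x.shape

    q, k, v = to_q(x), to_k(x), to_v(x)
    head_dim = q.shape[-1] // heads
    q, k, v = (t.view(b, s, heads, head_dim).transpose(1, 2) for t in (q, k, v))

    # RoPE is applied to the image tokens only; the text tokens are left untouched.
    q[:, :, text_seq:, :] = toy_rope(q[:, :, text_seq:, :])
    k[:, :, text_seq:, :] = toy_rope(k[:, :, text_seq:, :])

    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(b, s, heads * head_dim)

    # Split the joint sequence back into image and text streams.
    return out[:, text_seq:], out[:, :text_seq]


dim, heads = 64, 4
to_q, to_k, to_v = (torch.nn.Linear(dim, dim) for _ in range(3))
with torch.no_grad():
    img_out, txt_out = cogview4_attention_sketch(
        torch.randn(1, 16, dim), torch.randn(1, 4, dim), to_q, to_k, to_v, heads
    )
print(img_out.shape, txt_out.shape)  # torch.Size([1, 16, 64]) torch.Size([1, 4, 64])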

2 files changed: +135 −3 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 101 additions & 3 deletions
@@ -2802,6 +2802,105 @@ def __call__(
         return hidden_states
 
 
+class CogView4AttnProcessor:
+    """
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
+    the query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        ###############################################
+        # TODO: compute qkv directly from the fused qkv weight (split out num_heads and head_dim
+        # first), then split q, k, v along the head dimension.
+        linear_qkv_weight = torch.load("/home/lhy/code/cogview/linear_qkv_weight.pt")
+        linear_qkv_bias = torch.load("/home/lhy/code/cogview/linear_qkv_bias.pt")
+
+        qkv = torch.matmul(hidden_states, linear_qkv_weight.T) + linear_qkv_bias
+        qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        # TODO: verify that RoPE is applied correctly (currently ~25% error vs. Megatron).
+        ###############################################
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Apply RoPE if needed (image tokens only; text tokens are left untouched)
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb_megatron
+
+            query[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
+                query[:, :, text_seq_length:, :], image_rotary_emb
+            )
+            key[:, :, text_seq_length:, :] = apply_rotary_emb_megatron(
+                key[:, :, text_seq_length:, :], image_rotary_emb
+            )
+
+        ##########################################
+        # Debug: override q/k/v with reference tensors dumped after RoPE for comparison.
+        query = torch.load("/home/lhy/code/cogview/query_after_rope.pt")
+        key = torch.load("/home/lhy/code/cogview/key_after_rope.pt")
+        value = torch.load("/home/lhy/code/cogview/value_after_rope.pt")
+        query = query[None, :16 + 4096, ...]
+        key = key[None, :16 + 4096, ...]
+        value = value[None, :16 + 4096, ...]
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        ##########################################
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -2824,9 +2923,7 @@ def __call__(
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, sequence_length, _ = hidden_states.shape
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -6174,6 +6271,7 @@ def __call__(
                 FusedFluxAttnProcessor2_0,
                 FusedFluxAttnProcessor2_0_NPU,
                 CogVideoXAttnProcessor2_0,
+                CogView4AttnProcessor,
                 FusedCogVideoXAttnProcessor2_0,
                 XFormersAttnAddedKVProcessor,
                 XFormersAttnProcessor,
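The first TODO above concerns the fused linear_qkv weight loaded from the Megatron checkpoint. As a point of reference, here is a small sketch of how such a fused weight could be split back into separate q/k/v projection weights, assuming Megatron's per-head [q | k | v] row packing; the layout assumption and the split_fused_qkv_weight helper are hypothetical, not taken from the checkpoint itself.

import torch


def split_fused_qkv_weight(qkv_weight: torch.Tensor, num_heads: int):
    # Assumes rows are packed per head as [q_head0 | k_head0 | v_head0 | q_head1 | ...],
    # i.e. a fused weight of shape [3 * hidden_dim, hidden_dim].
    three_hidden, hidden = qkv_weight.shape
    head_dim = three_hidden // (3 * num_heads)
    w = qkv_weight.view(num_heads, 3, head_dim, hidden)
    wq, wk, wv = w.unbind(dim=1)  # each [num_heads, head_dim, hidden]
    return wq.reshape(-1, hidden), wk.reshape(-1, hidden), wv.reshape(-1, hidden)


# Toy check with random weights: applying wq is the q-part of the fused projection.
hidden_dim, num_heads = 32, 4
fused = torch.randn(3 * hidden_dim, hidden_dim)
wq, wk, wv = split_fused_qkv_weight(fused, num_heads)
x = torch.randn(2, 5, hidden_dim)
print((x @ wq.T).shape)  # torch.Size([2, 5, 32])

Under that layout assumption, de-interleaving the weights this way should make plain to_q/to_k/to_v projections agree with the per-head view(..., heads, head_dim * 3).chunk(3, dim=-1) used in the processor.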

src/diffusers/models/embeddings.py

Lines changed: 34 additions & 0 deletions
@@ -1283,6 +1283,40 @@ def apply_1d_rope(tokens, pos, cos, sin):
     x = torch.cat([t, h, w], dim=-1)
     return x
 
+def apply_rotary_emb_megatron(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    """Apply rotary position embeddings to the input tensor.
+
+    Args:
+        x: Input tensor of shape [batch_size, n_heads, seq_len, head_dim]
+        freqs: Frequency tensor of shape [seq_len, 1, 1, head_dim // 2]
+
+    Returns:
+        Tensor with rotary position embeddings applied
+    """
+    batch_size, n_heads, seq_len, rot_dim = x.shape
+
+    # Split x into the half that gets rotated and the half that passes through unchanged
+    x_rot, x_pass = x.chunk(2, dim=-1)
+
+    # Compute cos and sin from the frequencies
+    cos, sin = freqs.chunk(2, dim=-1)
+    cos, sin = torch.cos(cos), torch.sin(sin)
+
+    # Rotate x_rot
+    x_rot_cos = x_rot * cos
+    # Create a rotated version of x_rot by shifting rot_dim // 2 positions
+    x_rot_shifted = torch.cat([-x_rot[..., rot_dim // 2:], x_rot[..., :rot_dim // 2]], dim=-1)
+    x_rot_sin = x_rot_shifted * sin
+
+    # Combine the cos and sin terms
+    x_rot = x_rot_cos + x_rot_sin
+
+    # Concatenate back with the pass-through half
+    x_out = torch.cat([x_rot, x_pass], dim=-1)
+
+    return x_out
+
 
 class FluxPosEmbed(nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
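For the second TODO (the ~25% mismatch against Megatron), a common reference point is the standard Megatron-style rotate-half formulation, sketched below with toy shapes. This is only one possible convention for the frequency layout, not necessarily the one CogView4's checkpoint uses; apply_rope_reference and rotate_half are illustrative names.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; angles: [seq, head_dim // 2], one angle per rotary pair
    angles = torch.cat((angles, angles), dim=-1)  # [seq, head_dim]
    cos = angles.cos()[None, None, :, :]          # broadcast over batch and heads
    sin = angles.sin()[None, None, :, :]
    return x * cos + rotate_half(x) * sin


x = torch.randn(1, 2, 8, 16)
angles = torch.randn(8, 8)
print(apply_rope_reference(x, angles).shape)  # torch.Size([1, 2, 8, 16])

Comparing apply_rotary_emb_megatron against a reference like this on the dumped query_after_rope.pt / key_after_rope.pt tensors would help localize whether the mismatch comes from the frequency layout or from the rotation itself.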
