
Commit a7179a2

draft patch (not working)
1 parent 56ceaa6 commit a7179a2

File tree: 3 files changed (+100 lines, -9 lines)

src/diffusers/models/embeddings.py

Lines changed: 69 additions & 0 deletions
@@ -812,6 +812,75 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
 
         return (hidden_states + pos_embed).to(hidden_states.dtype)
 
+class CogView4PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+        # TODO: this needs to be changed to a rotary embedding (RoPE)
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        prompt_encoder_hidden_states = []
+        negative_prompt_encoder_hidden_states = []
+
+        for i in range(0, batch_size, 2):
+            prompt_embeds = encoder_hidden_states[i, :, :]  # [seq_len, hidden_size]
+            negative_embeds = encoder_hidden_states[i + 1, :, :]  # [seq_len, hidden_size]
+            mask = negative_embeds.abs().sum(dim=-1) > 0
+            seq_len_neg = mask.sum().item()  # number of non-zero rows
+            negative_embeds_valid = negative_embeds[:seq_len_neg, :]  # [seq_len_neg, hidden_size]
+            prompt_encoder_hidden_states.append(prompt_embeds)
+            negative_prompt_encoder_hidden_states.append(negative_embeds_valid)
+        prompt_encoder_hidden_states = torch.stack(prompt_encoder_hidden_states, dim=0)
+        negative_prompt_encoder_hidden_states = torch.stack(negative_prompt_encoder_hidden_states, dim=0)
+        prompt_text_length = prompt_encoder_hidden_states.shape[1]
+        negative_prompt_text_length = negative_prompt_encoder_hidden_states.shape[1]
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        prompt_text_pos_embed = torch.zeros(
+            (prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        negative_prompt_text_pos_embed = torch.zeros(
+            (negative_prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        prompt_pos_embed = torch.cat([prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
+        negative_prompt_pos_embed = torch.cat([negative_prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
+        # TODO: assemble these into one complete pos_embed, and concatenate the RoPE embed as well
+        pos_embed = torch.cat([prompt_pos_embed, negative_prompt_pos_embed], dim=0)
+        hidden_states = hidden_states + pos_embed.to(hidden_states.dtype)
+        return hidden_states
 
 def get_3d_rotary_pos_embed(
     embed_dim,
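For orientation, here is a minimal standalone sketch of the patchify step and the per-resolution slice of the positional-embedding table used in CogView4PatchEmbed.forward above. Shapes and values are illustrative only, not the diffusers API:

import torch

# Illustrative values matching the defaults above
batch_size, channel, height, width = 2, 16, 64, 64
patch_size, hidden_size, pos_embed_max_size = 2, 2560, 128

latents = torch.randn(batch_size, channel, height, width)

# Patchify: [B, C, H, W] -> [B, C, h, p, w, p] -> [B, h, w, C, p, p] -> [B, h*w, C*p*p]
h, w = height // patch_size, width // patch_size
patches = latents.view(batch_size, channel, h, patch_size, w, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
patches = patches.view(batch_size, h * w, channel * patch_size**2)
assert patches.shape == (2, 1024, 64)

# Project each flattened patch to the transformer width
proj = torch.nn.Linear(channel * patch_size**2, hidden_size)
tokens = proj(patches)  # [B, h*w, hidden_size]

# Slice the precomputed [max, max, hidden] sin-cos table down to this resolution
pos_table = torch.randn(pos_embed_max_size, pos_embed_max_size, hidden_size)  # stand-in for the buffer
image_pos_embed = pos_table[:h, :w].reshape(h * w, -1)  # [h*w, hidden_size]
assert image_pos_embed.shape == (1024, 2560)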

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 11 additions & 3 deletions
@@ -28,7 +28,7 @@
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed, CogView4PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
 
@@ -166,7 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
+    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed", "CogView4PatchEmbed"]
 
     @register_to_config
     def __init__(
 
@@ -191,7 +191,15 @@ def __init__(
         # Each of these are sincos embeddings of shape 2 * condition_dim
         self.pooled_projection_dim = 3 * 2 * condition_dim
 
-        self.patch_embed = CogView3PlusPatchEmbed(
+        # self.patch_embed = CogView3PlusPatchEmbed(
+        #     in_channels=in_channels,
+        #     hidden_size=self.inner_dim,
+        #     patch_size=patch_size,
+        #     text_hidden_size=text_embed_dim,
+        #     pos_embed_max_size=pos_embed_max_size,
+        # )
+        # TODO: compatibility adaptation
+        self.patch_embed = CogView4PatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
             patch_size=patch_size,
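For reference, with the stock CogView3Plus config this wiring resolves to roughly the constructor call below. This is a sketch only: it assumes CogView4PatchEmbed from the embeddings change above is in scope, and the values shown are the CogView3Plus defaults (which also match the new class's own defaults):

# Sketch of how the swapped-in patch embed is parameterized (illustrative defaults)
patch_embed = CogView4PatchEmbed(
    in_channels=16,          # latent channels
    hidden_size=2560,        # self.inner_dim = num_attention_heads * attention_head_dim
    patch_size=2,
    text_hidden_size=4096,   # text_embed_dim
    pos_embed_max_size=128,
)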

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 20 additions & 6 deletions
@@ -311,6 +311,24 @@ def encode_prompt(
            device=device,
            dtype=dtype,
        )
+
+        # TODO: zero-pad for now; handle differing sequence lengths properly later
+        seq_len_prompt = prompt_embeds.shape[1]
+        seq_len_neg = negative_prompt_embeds.shape[1]
+        if seq_len_neg < seq_len_prompt:
+            # Create a new tensor of shape [batch_size, seq_len_prompt, hidden_size]
+            batch_size = negative_prompt_embeds.shape[0]
+            hidden_size = negative_prompt_embeds.shape[2]
+            # The zero-padded tensor
+            padded_negative_prompt_embeds = torch.zeros(
+                batch_size,
+                seq_len_prompt,
+                hidden_size,
+                dtype=negative_prompt_embeds.dtype,
+                device=negative_prompt_embeds.device,
+            )
+            padded_negative_prompt_embeds[:, :seq_len_neg, :] = negative_prompt_embeds
+            negative_prompt_embeds = padded_negative_prompt_embeds
        return prompt_embeds, negative_prompt_embeds
 
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -582,7 +600,7 @@ def __call__(
            device=device,
        )
        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=1)
+            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
 
        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
@@ -594,7 +612,6 @@ def __call__(
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])  # Append zero at the end
 
        self.sigmas = time_shift(mu, 1.0, sigmas).to(torch.long).to("cpu")  # This is for the noise control of CogView4
-
        self._num_timesteps = len(timesteps)
 
        # 5. Prepare latents.
 
@@ -635,11 +652,8 @@ def __call__(
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
-
-                # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = latents  # For CogView4, the text embeds are concatenated and only the prompt is used
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
                # Use sigma instead of timestep directly
                sigma = self.sigmas[i]  # Get the corresponding sigma value
                timestep = sigma.expand(latent_model_input.shape[0]).to(device)  # Use sigma to scale the timestep
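Taken together, the last two hunks switch the pipeline to standard batched classifier-free guidance: the text embeds are stacked on the batch dimension and the latents are duplicated to match. A sketch of the resulting layout for a single prompt (illustrative shapes):

import torch

latents = torch.randn(1, 16, 64, 64)
prompt_embeds = torch.randn(1, 224, 4096)
negative_prompt_embeds = torch.randn(1, 224, 4096)  # already zero-padded to the prompt length

# Conditional first, unconditional second, stacked on dim 0
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)  # [2, 224, 4096]

# Duplicate the latents so batch index 0 (prompt) pairs with index 1 (negative),
# matching the stride-2 pairing loop in CogView4PatchEmbed.forward
latent_model_input = torch.cat([latents] * 2)  # [2, 16, 64, 64]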
