
Commit b3cadb8

draft of cogview3plus
1 parent 61d3764 commit b3cadb8

File tree

11 files changed: +1452 −1 lines changed

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import argparse
+from typing import Any, Dict
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import (
+    CogView3PlusTransformer2DModel,
+    CogView3PlusPipeline,
+)
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+    "transformer": "transformer_blocks",
+    "attention": "attn1",
+    "mlp": "ff.net",
+    "dense_h_to_4h": "0.proj",
+    "dense_4h_to_h": "2",
+    ".layers": "",
+    "dense": "to_out.0",
+    "patch_embed": "norm1.norm",
+    "post_attn1_layernorm": "norm2.norm",
+    "mixins.patch_embed": "patch_embed",
+    "mixins.final_layer.adaln": "norm_out",
+    "mixins.final_layer.linear": "proj_out",
+}
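
The rename map above is only half of the conversion story; this draft does not yet show how it is applied. As a rough sketch of the pattern such conversion scripts usually follow (the convert_transformer_state_dict helper below is hypothetical, not part of this commit), each key of the original CogView3-Plus checkpoint is rewritten by substring replacement before being loaded into the diffusers module:

def convert_transformer_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: apply every substring rename from
    # TRANSFORMER_KEYS_RENAME_DICT to each checkpoint key.
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted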

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,7 @@
         "AutoencoderOobleck",
         "AutoencoderTiny",
         "CogVideoXTransformer3DModel",
+        "CogView3PlusTransformer2DModel",
         "ConsistencyDecoderVAE",
         "ControlNetModel",
         "ControlNetXSAdapter",
@@ -258,6 +259,7 @@
         "CogVideoXImageToVideoPipeline",
         "CogVideoXPipeline",
         "CogVideoXVideoToVideoPipeline",
+        "CogView3PlusPipeline",
         "CycleDiffusionPipeline",
         "FluxControlNetImg2ImgPipeline",
         "FluxControlNetInpaintPipeline",
@@ -558,6 +560,7 @@
         AutoencoderOobleck,
         AutoencoderTiny,
         CogVideoXTransformer3DModel,
+        CogView3PlusTransformer2DModel,
         ConsistencyDecoderVAE,
         ControlNetModel,
         ControlNetXSAdapter,
@@ -710,6 +713,7 @@
         CogVideoXImageToVideoPipeline,
         CogVideoXPipeline,
         CogVideoXVideoToVideoPipeline,
+        CogView3PlusPipeline,
         CycleDiffusionPipeline,
         FluxControlNetImg2ImgPipeline,
         FluxControlNetInpaintPipeline,
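
With these entries added, CogView3PlusPipeline and CogView3PlusTransformer2DModel become importable from the top-level diffusers package. A minimal usage sketch, assuming a converted checkpoint is available locally and that the pipeline follows the usual DiffusionPipeline text-to-image call convention; the checkpoint path is a placeholder, not a published model ID:

import torch
from diffusers import CogView3PlusPipeline

# "path/to/cogview3plus-checkpoint" is a placeholder for a locally converted checkpoint.
pipe = CogView3PlusPipeline.from_pretrained("path/to/cogview3plus-checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumes the standard text-to-image interface: prompt in, .images out.
image = pipe(prompt="a photograph of an astronaut riding a horse").images[0]
image.save("cogview3plus_sample.png")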

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
     _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
+    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
@@ -98,6 +99,7 @@
     from .transformers import (
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
+        CogView3PlusTransformer2DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
         FluxTransformer2DModel,

src/diffusers/models/embeddings.py

Lines changed: 108 additions & 0 deletions
@@ -714,6 +714,114 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         return freqs_cos, freqs_sin


+class CogView3PlusPosEmbed(nn.Module):
+    def __init__(
+        self,
+        max_height: int = 128,
+        max_width: int = 128,
+        hidden_size: int = 2560,
+        text_length: int = 0,
+        block_size: int = 16,
+    ):
+        super().__init__()
+        self.max_height = max_height
+        self.max_width = max_width
+        self.hidden_size = hidden_size
+        self.text_length = text_length
+        self.block_size = block_size
+
+        # Initialize the positional embedding as a non-trainable parameter
+        self.image_pos_embedding = nn.Parameter(
+            torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
+        )
+        # Reinitialize the positional embedding using a sin-cos function
+        self.reinit()
+
+    def forward(self, target_size: List[int]) -> torch.Tensor:
+        ret = []
+        for h, w in target_size:
+            # Scale height and width according to the block size
+            h, w = h // self.block_size, w // self.block_size
+
+            # Reshape the image positional embedding for the target size
+            image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
+
+            # Combine the text positional embedding and image positional embedding
+            pos_embed = torch.cat(
+                [
+                    torch.zeros(
+                        (self.text_length, self.hidden_size),
+                        dtype=image_pos_embed.dtype,
+                        device=image_pos_embed.device,
+                    ),
+                    image_pos_embed,
+                ],
+                dim=0,
+            )
+
+            ret.append(pos_embed[None, ...])  # Add a batch dimension
+
+        return torch.cat(ret, dim=0)  # Concatenate along the batch dimension
+
+    def reinit(self):
+        # Initialize the positional embedding using a 2D sin-cos function
+        pos_embed_np = self.get_2d_sincos_pos_embed(self.hidden_size, self.max_height, self.max_width)
+        self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())


+class CogView3PlusImagePatchEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 128,
+        hidden_size: int = 128,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+    def forward(self, images: torch.Tensor, encoder_outputs: torch.Tensor = None) -> torch.Tensor:
+        # Rearrange the images
+        # patches_images = rearrange(images, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=self.patch_size, p2=self.patch_size)
+
+        b, c, h, w = images.shape
+        p1, p2 = self.patch_size, self.patch_size
+        assert h % p1 == 0 and w % p2 == 0, "Height and width must be divisible by patch size"
+
+        images = images.view(b, c, h // p1, p1, w // p2, p2)
+        patches_images = images.permute(0, 2, 4, 1, 3, 5).contiguous()
+        patches_images = patches_images.view(b, (h // p1) * (w // p2), c * p1 * p2)
+
+        # Project the patches
+        image_emb = self.proj(patches_images)
+
+        # If text embeddings are provided, project and concatenate them
+        if self.text_hidden_size is not None and encoder_outputs is not None:
+            text_emb = self.text_proj(encoder_outputs)
+            emb = torch.cat([text_emb, image_emb], dim=1)
+        else:
+            emb = image_emb
+
+        return emb
+
+    def reinit(self, parent_model=None):
+        # Reinitialize the projection weights
+        nn.init.xavier_uniform_(self.proj.weight)
+        nn.init.constant_(self.proj.bias, 0)
+        if self.text_hidden_size is not None:
+            nn.init.xavier_uniform_(self.text_proj.weight)
+            nn.init.constant_(self.text_proj.bias, 0)
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
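
To make the shapes concrete, here is a small smoke-test sketch of the new CogView3PlusImagePatchEmbedding with made-up sizes (the numbers are illustrative, not the model's real dimensions): a 2x2 patching of an 8x8, 4-channel latent yields 16 image tokens, and the projected text tokens are prepended to them.

import torch

# Illustrative sizes only; the real model uses much larger hidden dimensions.
patch_embed = CogView3PlusImagePatchEmbedding(
    in_channels=4, hidden_size=64, patch_size=2, text_hidden_size=32
)

latents = torch.randn(1, 4, 8, 8)    # (batch, channels, height, width)
text_states = torch.randn(1, 7, 32)  # (batch, text_tokens, text_hidden_size)

tokens = patch_embed(latents, text_states)
print(tokens.shape)  # torch.Size([1, 23, 64]): 7 text tokens followed by 16 image patch tokens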

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,4 +16,5 @@
 from .transformer_2d import Transformer2DModel
 from .transformer_flux import FluxTransformer2DModel
 from .transformer_sd3 import SD3Transformer2DModel
+from .transformer_cogview3plus import CogView3PlusTransformer2DModel
 from .transformer_temporal import TransformerTemporalModel
