
Commit cefca0f

right way
1 parent 21439e2 commit cefca0f

File tree

5 files changed: +302 -309 lines changed


scripts/convert_cogview3_to_diffusers.py

Lines changed: 10 additions & 22 deletions
@@ -37,23 +37,11 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    layer_id, _, weight_or_bias = key.split(".")[-3:]

-    weights_or_biases = state_dict[key].chunk(12, dim=0)
-    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
-    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
-
-    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
-    state_dict[norm1_key] = norm1_weights_or_biases
-
-    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
-    state_dict[norm2_key] = norm2_weights_or_biases
-
-    state_dict.pop(key)
-
-
-def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+    weights_or_biases = state_dict[key]
+    norm1_key = f"transformer_blocks.{layer_id}.adaln_modules.1.{weight_or_bias}"
+    state_dict[norm1_key] = weights_or_biases
    state_dict.pop(key)

-
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
@@ -73,16 +61,17 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
-    "mixins.patch_embed": "image_patch_embed",
-    "mixins.adaln.adaln_modules": "adaln_module",
-    "time_embed": "time_embed",
-    "label_emb": "label_embed",
-    "mixins.final_layer.adaln": "final_layer.adaln",
+    "mixins.patch_embed": "pos_embed",
+    "time_embed.0": "emb.timestep_embedder.linear_1",
+    "time_embed.2": "emb.timestep_embedder.linear_2",
+    "label_emb.0": "emb.label_embedder",
+    "mixins.final_layer.adaln.1": "norm_out.linear",
    "mixins.final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
+    "mixins.adaln.adaln_modules": reassign_adaln_norm_inplace,
}

TOKENIZER_MAX_LENGTH = 224
@@ -135,9 +124,8 @@ def convert_transformer(
    transformer = CogView3PlusTransformer2DModel(
        in_channels=16,
        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
+        num_attention_heads=num_attention_heads
    ).to(dtype=dtype)
-
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY):]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
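For orientation, the rename dict and the special-keys map above feed a key-rewriting pass over the original state dict. The snippet below is only a minimal sketch of that pattern, assuming the usual structure of diffusers conversion scripts; convert_state_dict is a hypothetical helper, and only the names visible in this diff (PREFIX_KEY, TRANSFORMER_KEYS_RENAME_DICT, TRANSFORMER_SPECIAL_KEYS_REMAP, reassign_adaln_norm_inplace) come from the file itself.

# Illustrative sketch of the remapping pass, not the script's literal code.
def convert_state_dict(original_state_dict):
    # 1) strip the checkpoint prefix and apply plain substring renames
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY):]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        original_state_dict[new_key] = original_state_dict.pop(key)

    # 2) let special handlers (e.g. reassign_adaln_norm_inplace) restructure keys in place
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)
    return original_state_dict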

show_model.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import torch
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from sgm.models.autoencoder import AutoencodingEngine

# (1) create vae_sat
# AutoencodingEngine initialization arguments:
encoder_config = {'target': 'sgm.modules.diffusionmodules.model.Encoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
decoder_config = {'target': 'sgm.modules.diffusionmodules.model.Decoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
loss_config = {'target': 'torch.nn.Identity'}
regularizer_config = {'target': 'sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'}
optimizer_config = None
lr_g_factor = 1.0
ckpt_path = "/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/3plus_ae/imagekl_ch16.pt"
ignore_keys = []
kwargs = {"monitor": "val/rec_loss"}
vae_sat = AutoencodingEngine(
    encoder_config=encoder_config,
    decoder_config=decoder_config,
    loss_config=loss_config,
    regularizer_config=regularizer_config,
    optimizer_config=optimizer_config,
    lr_g_factor=lr_g_factor,
    ckpt_path=ckpt_path,
    ignore_keys=ignore_keys,
    **kwargs)

# (2) create vae (diffusers)
ckpt_path_vae_cogview3 = hf_hub_download(repo_id="ZP2HF/CogView3-SAT", subfolder="3plus_ae", filename="imagekl_ch16.pt")
cogview3_ckpt = torch.load(ckpt_path_vae_cogview3, map_location="cpu")["state_dict"]

in_channels = 3  # Inferred from encoder.conv_in.weight shape
out_channels = 3  # Inferred from decoder.conv_out.weight shape
down_block_types = ("DownEncoderBlock2D",) * 4  # Inferred from the presence of 4 encoder.down blocks
up_block_types = ("UpDecoderBlock2D",) * 4  # Inferred from the presence of 4 decoder.up blocks
block_out_channels = (128, 512, 1024, 1024)  # Inferred from the channel sizes in encoder.down blocks
layers_per_block = 3  # Inferred from the number of blocks in each encoder.down and decoder.up
act_fn = "silu"  # This is the default, cannot be inferred from state_dict
latent_channels = 16  # Inferred from decoder.conv_in.weight shape
norm_num_groups = 32  # This is the default, cannot be inferred from state_dict
sample_size = 1024  # This is the default, cannot be inferred from state_dict
scaling_factor = 0.18215  # This is the default, cannot be inferred from state_dict
force_upcast = True  # This is the default, cannot be inferred from state_dict
use_quant_conv = False  # Inferred from the presence of encoder.conv_out
use_post_quant_conv = False  # Inferred from the presence of decoder.conv_in
mid_block_add_attention = False  # Inferred from the absence of attention layers in mid blocks

vae = AutoencoderKL(
    in_channels=in_channels,
    out_channels=out_channels,
    down_block_types=down_block_types,
    up_block_types=up_block_types,
    block_out_channels=block_out_channels,
    layers_per_block=layers_per_block,
    act_fn=act_fn,
    latent_channels=latent_channels,
    norm_num_groups=norm_num_groups,
    sample_size=sample_size,
    scaling_factor=scaling_factor,
    force_upcast=force_upcast,
    use_quant_conv=use_quant_conv,
    use_post_quant_conv=use_post_quant_conv,
    mid_block_add_attention=mid_block_add_attention,
)

vae.eval()
vae_sat.eval()

converted_vae_state_dict = convert_ldm_vae_checkpoint(cogview3_ckpt, vae.config)
vae.load_state_dict(converted_vae_state_dict, strict=False)

# (3) run forward pass for both models

# [2, 16, 128, 128] -> [2, 3, 1024, 1024]
z = torch.load("z.pt").float().to("cpu")

with torch.no_grad():
    print(" ")
    print("running forward pass for diffusers vae")
    out = vae.decode(z).sample
    print(" ")
    print("running forward pass for sgm vae")
    out_sat = vae_sat.decode(z)

print(f"output shape: {out.shape}")
print(f"expected output shape: {out_sat.shape}")
assert out.shape == out_sat.shape
assert (out - out_sat).abs().max() < 1e-4, f"max diff: {(out - out_sat).abs().max()}"
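Most of the AutoencoderKL arguments above are annotated as "inferred from" the checkpoint. A quick way to double-check those inferences is to inspect the state dict directly; the sketch below is a hypothetical helper, and the key names assume the standard LDM/SGM VAE layout that convert_ldm_vae_checkpoint expects.

import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="ZP2HF/CogView3-SAT", subfolder="3plus_ae", filename="imagekl_ch16.pt")
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# encoder.conv_in.weight has shape (ch, in_channels, k, k); decoder.conv_in.weight has
# shape (block_in, latent_channels, k, k), so these two prints confirm in_channels and latent_channels.
print("encoder.conv_in.weight:", tuple(sd["encoder.conv_in.weight"].shape))
print("decoder.conv_in.weight:", tuple(sd["decoder.conv_in.weight"].shape))

# Count down blocks and res blocks per block from the key structure.
down_blocks = sorted({k.split(".")[2] for k in sd if k.startswith("encoder.down.")})
res_blocks = sorted({k.split(".")[4] for k in sd if k.startswith("encoder.down.0.block.")})
print("encoder down blocks:", down_blocks)  # 4 blocks -> ("DownEncoderBlock2D",) * 4
print("res blocks in down.0:", res_blocks)  # 3 blocks -> layers_per_block = 3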

show_model_cogview.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
import torch
from diffusers import CogView3PlusTransformer2DModel

model = CogView3PlusTransformer2DModel.from_pretrained("/share/home/zyx/Models/CogView3Plus_hf/transformer", torch_dtype=torch.bfloat16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

batch_size = 1
hidden_states = torch.ones((batch_size, 16, 256, 256), device=device, dtype=torch.bfloat16)
timestep = torch.full((batch_size,), 999.0, device=device, dtype=torch.bfloat16)
y = torch.ones((batch_size, 1536), device=device, dtype=torch.bfloat16)

# Simulate a call to the forward method
outputs = model(
    hidden_states=hidden_states,  # hidden_states input
    timestep=timestep,  # timestep input
    y=y,  # label input
    block_controlnet_hidden_states=None,  # can be omitted if not needed
    return_dict=True,  # keep the default
    target_size=[(2048, 2048)],
)

# Print the model output
print("Output shape:", outputs.sample.shape)

src/diffusers/models/embeddings.py

Lines changed: 83 additions & 64 deletions
@@ -714,68 +714,68 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return freqs_cos, freqs_sin


-class CogView3PlusPosEmbed(nn.Module):
-    def __init__(
-        self,
-        max_height: int = 128,
-        max_width: int = 128,
-        hidden_size: int = 2560,
-        text_length: int = 0,
-        block_size: int = 16,
-    ):
-        super().__init__()
-        self.max_height = max_height
-        self.max_width = max_width
-        self.hidden_size = hidden_size
-        self.text_length = text_length
-        self.block_size = block_size
-
-        # Initialize the positional embedding as a non-trainable parameter
-        self.image_pos_embedding = nn.Parameter(
-            torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
-        )
-        # Reinitialize the positional embedding using a sin-cos function
-        self.reinit()
-
-    def forward(self, target_size: List[int]) -> torch.Tensor:
-        ret = []
-        for h, w in target_size:
-            # Scale height and width according to the block size
-            h, w = h // self.block_size, w // self.block_size
-
-            # Reshape the image positional embedding for the target size
-            image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
-
-            # Combine the text positional embedding and image positional embedding
-            pos_embed = torch.cat(
-                [
-                    torch.zeros(
-                        (self.text_length, self.hidden_size),
-                        dtype=image_pos_embed.dtype,
-                        device=image_pos_embed.device,
-                    ),
-                    image_pos_embed,
-                ],
-                dim=0,
-            )
-
-            ret.append(pos_embed[None, ...])  # Add a batch dimension
-
-        return torch.cat(ret, dim=0)  # Concatenate along the batch dimension
-
-    def reinit(self):
-        # Initialize the positional embedding using the updated 2D sin-cos function
-        grid_size = (self.max_height, self.max_width)
-        pos_embed_np = get_2d_sincos_pos_embed(
-            embed_dim=self.hidden_size,
-            grid_size=grid_size,
-        )
-
-        # Reshape the positional embedding to the desired shape
-        pos_embed_np = pos_embed_np.reshape(self.max_height, self.max_width, self.hidden_size)
-
-        # Copy the positional embedding data
-        self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())
+# class CogView3PlusPosEmbed(nn.Module):
+#     def __init__(
+#         self,
+#         max_height: int = 128,
+#         max_width: int = 128,
+#         hidden_size: int = 2560,
+#         text_length: int = 0,
+#         block_size: int = 16,
+#     ):
+#         super().__init__()
+#         self.max_height = max_height
+#         self.max_width = max_width
+#         self.hidden_size = hidden_size
+#         self.text_length = text_length
+#         self.block_size = block_size
+#
+#         # Initialize the positional embedding as a non-trainable parameter
+#         self.image_pos_embedding = nn.Parameter(
+#             torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
+#         )
+#         # Reinitialize the positional embedding using a sin-cos function
+#         self.reinit()
+#
+#     def forward(self, target_size: List[int]) -> torch.Tensor:
+#         ret = []
+#         for h, w in target_size:
+#             # Scale height and width according to the block size
+#             h, w = h // self.block_size, w // self.block_size
+#
+#             # Reshape the image positional embedding for the target size
+#             image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
+#
+#             # Combine the text positional embedding and image positional embedding
+#             pos_embed = torch.cat(
+#                 [
+#                     torch.zeros(
+#                         (self.text_length, self.hidden_size),
+#                         dtype=image_pos_embed.dtype,
+#                         device=image_pos_embed.device,
+#                     ),
+#                     image_pos_embed,
+#                 ],
+#                 dim=0,
+#             )
+#
+#             ret.append(pos_embed[None, ...])  # Add a batch dimension
+#
+#         return torch.cat(ret, dim=0)  # Concatenate along the batch dimension
+#
+#     def reinit(self):
+#         # Initialize the positional embedding using the updated 2D sin-cos function
+#         grid_size = (self.max_height, self.max_width)
+#         pos_embed_np = get_2d_sincos_pos_embed(
+#             embed_dim=self.hidden_size,
+#             grid_size=grid_size,
+#         )
+#
+#         # Reshape the positional embedding to the desired shape
+#         pos_embed_np = pos_embed_np.reshape(self.max_height, self.max_width, self.hidden_size)
+#
+#         # Copy the positional embedding data
+#         self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())


class CogView3PlusImagePatchEmbedding(nn.Module):
@@ -809,8 +809,6 @@ def forward(self, images: torch.Tensor, encoder_outputs: torch.Tensor = None) ->
        images = images.view(b, c, h // p1, p1, w // p2, p2)
        patches_images = images.permute(0, 2, 4, 1, 3, 5).contiguous()
        patches_images = patches_images.view(b, (h // p1) * (w // p2), c * p1 * p2)
-
-        # Project the patches
        image_emb = self.proj(patches_images)

        # If text embeddings are provided, project and concatenate them
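The three reshape lines above are the standard patchify step: a (B, C, H, W) latent becomes a (B, (H/p1)*(W/p2), C*p1*p2) token sequence before the linear projection. A small self-contained check of that reshape pattern, with illustrative sizes rather than the model's real ones:

import torch

# Illustrative sizes only: 16 latent channels, an 8x8 latent, patch size 2.
b, c, h, w = 1, 16, 8, 8
p1 = p2 = 2
images = torch.randn(b, c, h, w)

x = images.view(b, c, h // p1, p1, w // p2, p2)
x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
x = x.view(b, (h // p1) * (w // p2), c * p1 * p2)

print(x.shape)  # torch.Size([1, 16, 64]) -> (batch, num_patches, channels * p1 * p2)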
@@ -1135,6 +1133,27 @@ def forward(self, image_embeds: torch.Tensor):
        return self.norm(x)


+class CogView3CombineTimestepLabelEmbedding(nn.Module):
+    def __init__(self, time_embed_dim, label_embed_dim, in_channels=2560):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
+        self.timestep_embedder = TimestepEmbedding(in_channels=in_channels, time_embed_dim=time_embed_dim)
+        self.label_embedder = nn.Sequential(
+            nn.Linear(label_embed_dim, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+    def forward(self, timestep, class_labels, hidden_dtype=None):
+        t_proj = self.time_proj(timestep)
+        t_emb = self.timestep_embedder(t_proj.to(dtype=hidden_dtype))
+        label_emb = self.label_embedder(class_labels)
+        emb = t_emb + label_emb
+
+        return emb
+
+
class CombinedTimestepLabelEmbeddings(nn.Module):
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()
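For reference, a minimal usage sketch of the new CogView3CombineTimestepLabelEmbedding. The dimensions below are illustrative, not the actual CogView3Plus configuration, and the import assumes the class is exposed from diffusers.models.embeddings once this commit lands.

import torch
from diffusers.models.embeddings import CogView3CombineTimestepLabelEmbedding  # assumed import path

# Illustrative dimensions; the real model config may differ.
emb_module = CogView3CombineTimestepLabelEmbedding(time_embed_dim=512, label_embed_dim=1536, in_channels=256)

timestep = torch.full((2,), 999.0)  # one timestep per batch element
y = torch.randn(2, 1536)            # pooled label / text embedding
emb = emb_module(timestep, y, hidden_dtype=torch.float32)

print(emb.shape)  # torch.Size([2, 512]) -> sum of timestep and label embeddings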
