Skip to content

Commit dff4b29

Browse files
train con and uc
1 parent 1d91a24 commit dff4b29

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,6 @@ def main(args):
758758
revision=args.revision,
759759
variant=args.variant,
760760
)
761-
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
762761
cogview4_transformer = CogView4Transformer2DModel.from_pretrained(
763762
args.pretrained_model_name_or_path,
764763
subfolder="transformer",
@@ -1081,9 +1080,8 @@ def load_model_hook(models, input_dir):
10811080
#TODO: Should a parameter be set here for passing? This is not present in Flux.
10821081
crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
10831082
crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
1084-
10851083
# Predict.
1086-
model_pred = cogview4_transformer(
1084+
noise_pred_cond = cogview4_transformer(
10871085
hidden_states=concatenated_noisy_model_input,
10881086
encoder_hidden_states=prompt_embeds,
10891087
timestep=timesteps,
@@ -1093,6 +1091,16 @@ def load_model_hook(models, input_dir):
10931091
return_dict=False,
10941092
)[0]
10951093

1094+
noise_pred_uncond = cogview4_transformer(
1095+
hidden_states=concatenated_noisy_model_input,
1096+
encoder_hidden_states=pooled_prompt_embeds,
1097+
timestep=timesteps,
1098+
original_size=original_size,
1099+
target_size=target_size,
1100+
crop_coords=crops_coords_top_left,
1101+
return_dict=False,
1102+
)[0]
1103+
model_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
10961104
# these weighting schemes use a uniform timestep sampling
10971105
# and instead post-weight the loss
10981106
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

0 commit comments

Comments
 (0)