@@ -40,9 +40,9 @@
 import diffusers
 from diffusers import (
     AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
     CogView4ControlPipeline,
     CogView4Transformer2DModel,
+    FlowMatchEulerDiscreteScheduler,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -977,6 +977,7 @@ def load_model_hook(models, input_dir):
     text_encoding_pipeline = CogView4ControlPipeline.from_pretrained(
         args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
     )
+    tokenizer = text_encoding_pipeline.tokenizer

     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
@@ -1043,6 +1044,16 @@ def load_model_hook(models, input_dir):
             with accelerator.accumulate(cogview4_transformer):
                 # Convert images to latent space
                 # vae encode
+                prompts = batch["captions"]
+                attention_mask = tokenizer(
+                    prompts,
+                    padding="longest",  # pad only to the longest prompt in the batch, not to max_length
+                    max_length=args.max_sequence_length,
+                    truncation=True,
+                    add_special_tokens=True,
+                    return_tensors="pt",
+                ).attention_mask.float()
+
                 pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
                 control_latents = encode_images(
                     batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
@@ -1119,6 +1130,7 @@ def load_model_hook(models, input_dir):
                     target_size=target_size,
                     crop_coords=crops_coords_top_left,
                     return_dict=False,
+                    attention_mask=attention_mask,
                 )[0]
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
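
Taken together, these hunks grab the pipeline's tokenizer once at setup, build a float attention mask over the batch captions on every training step, and pass that mask into the transformer's forward call. A minimal sketch of the masking step, assuming a generic Hugging Face tokenizer ("bert-base-uncased" is a stand-in for the pipeline's own tokenizer, and the max_sequence_length value is illustrative):

from transformers import AutoTokenizer

# Stand-in tokenizer; the training script uses text_encoding_pipeline.tokenizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

captions = ["a photo of a cat", "an astronaut riding a horse on the moon"]
max_sequence_length = 224  # illustrative; the script reads args.max_sequence_length

attention_mask = tokenizer(
    captions,
    padding="longest",            # pad only to the longest caption in the batch
    max_length=max_sequence_length,
    truncation=True,              # clip captions longer than max_sequence_length
    add_special_tokens=True,
    return_tensors="pt",
).attention_mask.float()          # (batch, longest_len): 1.0 = real token, 0.0 = padding

print(attention_mask.shape)       # e.g. torch.Size([2, 10])

Padding to the longest caption rather than to max_length keeps the per-step sequence length (and attention cost) as small as the batch allows, while truncation still enforces the max_sequence_length ceiling.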