
Commit 9a10ceb

[fix] Add attention mask for padded token
1 parent 264060e commit 9a10ceb

1 file changed, +13 −1 lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 13 additions & 1 deletion
@@ -40,9 +40,9 @@
 import diffusers
 from diffusers import (
     AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
     CogView4ControlPipeline,
     CogView4Transformer2DModel,
+    FlowMatchEulerDiscreteScheduler,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -977,6 +977,7 @@ def load_model_hook(models, input_dir):
     text_encoding_pipeline = CogView4ControlPipeline.from_pretrained(
         args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
     )
+    tokenizer = text_encoding_pipeline.tokenizer

     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
@@ -1043,6 +1044,16 @@ def load_model_hook(models, input_dir):
             with accelerator.accumulate(cogview4_transformer):
                 # Convert images to latent space
                 # vae encode
+                prompts = batch["captions"]
+                attention_mask = tokenizer(
+                    prompts,
+                    padding="longest",  # pad to the batch's longest prompt, not to max_length
+                    max_length=args.max_sequence_length,
+                    truncation=True,
+                    add_special_tokens=True,
+                    return_tensors="pt",
+                ).attention_mask.float()
+
                 pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
                 control_latents = encode_images(
                     batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
@@ -1119,6 +1130,7 @@ def load_model_hook(models, input_dir):
                     target_size=target_size,
                     crop_coords=crops_coords_top_left,
                     return_dict=False,
+                    attention_mask=attention_mask,
                 )[0]
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
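
For reference, a minimal standalone sketch of what this change computes: the batch captions are tokenized with batch-level padding, and the resulting attention mask (1.0 for real tokens, 0.0 for padding) is what gets forwarded to the transformer so padded positions are ignored. The checkpoint id "THUDM/CogView4-6B" and the max length of 224 below are illustrative assumptions, not values taken from this commit.

# Sketch of the attention-mask construction added above (assumptions noted in comments).
from transformers import AutoTokenizer

# Assumed checkpoint; the training script instead takes the tokenizer from CogView4ControlPipeline.
tokenizer = AutoTokenizer.from_pretrained("THUDM/CogView4-6B", subfolder="tokenizer")

prompts = [
    "a photo of a cat",
    "a long caption describing a busy city street at night with neon signs",
]
attention_mask = tokenizer(
    prompts,
    padding="longest",        # pad only up to the longest prompt in this batch
    max_length=224,           # stands in for args.max_sequence_length (assumed value)
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
).attention_mask.float()

# Shape: (batch_size, longest_prompt_length); 1.0 marks real tokens, 0.0 marks padding.
# In the training loop this tensor is passed as attention_mask= to cogview4_transformer.
print(attention_mask.shape)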

0 commit comments
