Skip to content

Commit 7a68a3e

Browse files
use imagetoken
1 parent 940c23b commit 7a68a3e

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,9 +1056,7 @@ def load_model_hook(models, input_dir):
10561056
timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
10571057
sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device)
10581058
captions = batch["captions"]
1059-
token_lengths = [len(caption.split()) for caption in captions]
1060-
token_per_sample = max(token_lengths)
1061-
image_seq_lens = torch.tensor(token_per_sample // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device)
1059+
image_seq_lens = torch.tensor(pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device) # H * W / VAE patch_size
10621060
mu = torch.sqrt(image_seq_lens / 256)
10631061
mu = mu * 0.75 + 0.25
10641062
scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device)

0 commit comments

Comments
 (0)