Skip to content

Commit a9f448e

Browse files
1
1 parent 44bfd4c commit a9f448e

File tree

2 files changed

+0
-477
lines changed

2 files changed

+0
-477
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,18 +1025,6 @@ def load_model_hook(models, input_dir):
10251025
disable=not accelerator.is_local_main_process,
10261026
)
10271027

1028-
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1029-
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1030-
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1031-
timesteps = timesteps.to(accelerator.device)
1032-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1033-
1034-
sigma = sigmas[step_indices].flatten()
1035-
while len(sigma.shape) < n_dim:
1036-
sigma = sigma.unsqueeze(-1)
1037-
return sigma
1038-
1039-
image_logs = None
10401028
for epoch in range(first_epoch, args.num_train_epochs):
10411029
cogview4_transformer.train()
10421030
for step, batch in enumerate(train_dataloader):

0 commit comments

Comments
 (0)