Skip to content

Commit 5d2e994

Browse files
new loss
1 parent b9d864b commit 5d2e994

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
41 41
from diffusers.optimization import get_scheduler
42 42
from diffusers.training_utils import (
43 43
compute_density_for_timestep_sampling,
44 -
compute_loss_weighting_for_sd3,
45 44
free_memory,
46 45
)
47 46
from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid
@@ -804,7 +803,7 @@ def main(args):
804 803
cogview4_transformer.patch_embed.proj = new_linear
805 804

806 805
assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)
807 -
cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=initial_input_channels)
806 +
cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels)
808 807

809 808
if args.only_target_transformer_blocks:
810 809
cogview4_transformer.patch_embed.proj.requires_grad_(True)
@@ -1097,7 +1096,7 @@ def load_model_hook(models, input_dir):
1097 1096

1098 1097
# these weighting schemes use a uniform timestep sampling
1099 1098
# and instead post-weight the loss
1100 -
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1099 +
weighting = (sigmas**-2.0).float()
1101 1100
# flow-matching loss
1102 1101
target = noise - pixel_latents
1103 1102

0 commit comments

Comments
 (0)