Skip to content

Commit 6090ea7

Browse files
draft schedule
1 parent 6163679 commit 6090ea7

File tree

6 files changed

+589
-27
lines changed

6 files changed

+589
-27
lines changed

scripts/convert_cogview4_to_diffusers.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@
3131
from accelerate import init_empty_weights
3232
from transformers import PreTrainedTokenizerFast, GlmForCausalLM
3333

34-
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
34+
from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
3535
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
3636
from diffusers.utils.import_utils import is_accelerate_available
3737

38-
3938
CTX = init_empty_weights if is_accelerate_available() else nullcontext
4039

4140
parser = argparse.ArgumentParser()
@@ -170,16 +169,16 @@ def main(args):
170169
args.transformer_checkpoint_path
171170
)
172171
transformer = CogView3PlusTransformer2DModel(
173-
patch_size = 2,
174-
in_channels = 16,
175-
num_layers = 28,
176-
attention_head_dim= 128,
177-
num_attention_heads = 32,
178-
out_channels = 16,
179-
text_embed_dim= 4096,
180-
time_embed_dim = 512,
181-
condition_dim= 256,
182-
pos_embed_max_size = 128,
172+
patch_size=2,
173+
in_channels=16,
174+
num_layers=28,
175+
attention_head_dim=128,
176+
num_attention_heads=32,
177+
out_channels=16,
178+
text_embed_dim=4096,
179+
time_embed_dim=512,
180+
condition_dim=256,
181+
pos_embed_max_size=128,
183182
)
184183
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
185184
if dtype is not None:
@@ -210,16 +209,20 @@ def main(args):
210209
if dtype is not None:
211210
vae = vae.to(dtype=dtype)
212211

213-
text_encoder_id = 'THUDM/glm-4-9b-hf'
212+
text_encoder_id = "THUDM/glm-4-9b-hf"
214213
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
215-
text_encoder = GlmForCausalLM.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir, torch_dtype=torch.bfloat16 if dtype=="bf16" else torch.float32)
214+
text_encoder = GlmForCausalLM.from_pretrained(
215+
text_encoder_id,
216+
cache_dir=args.text_encoder_cache_dir,
217+
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
218+
)
216219
# Apparently, the conversion does not work anymore without this :shrug:
217220
for param in text_encoder.parameters():
218221
param.data = param.data.contiguous()
219222

220-
scheduler = CogVideoXDDIMScheduler.from_config(
223+
scheduler = CogView4DDIMScheduler.from_config(
221224
{
222-
"snr_shift_scale": 4.0,
225+
"shift_scale": 1.0,
223226
"beta_end": 0.012,
224227
"beta_schedule": "scaled_linear",
225228
"beta_start": 0.00085,

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
"CMStochasticIterativeScheduler",
176176
"CogVideoXDDIMScheduler",
177177
"CogVideoXDPMScheduler",
178+
"CogView4DDIMScheduler",
178179
"DDIMInverseScheduler",
179180
"DDIMParallelScheduler",
180181
"DDIMScheduler",
@@ -684,6 +685,7 @@
684685
CMStochasticIterativeScheduler,
685686
CogVideoXDDIMScheduler,
686687
CogVideoXDPMScheduler,
688+
CogView4DDIMScheduler,
687689
DDIMInverseScheduler,
688690
DDIMParallelScheduler,
689691
DDIMScheduler,

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ...image_processor import VaeImageProcessor
2424
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
2525
from ...pipelines.pipeline_utils import DiffusionPipeline
26-
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
26+
from ...schedulers import CogView4DDIMScheduler
2727
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2828
from ...utils.torch_utils import randn_tensor
2929
from .pipeline_output import CogView4PipelineOutput
@@ -151,7 +151,7 @@ def __init__(
151151
text_encoder: GlmModel,
152152
vae: AutoencoderKL,
153153
transformer: CogView3PlusTransformer2DModel,
154-
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
154+
scheduler: CogView4DDIMScheduler,
155155
):
156156
super().__init__()
157157

@@ -318,7 +318,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
318318
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
319319
else:
320320
latents = latents.to(device)
321-
322321
# scale the initial noise by the standard deviation required by the scheduler
323322
latents = latents * self.scheduler.init_noise_sigma
324323
return latents
@@ -517,8 +516,8 @@ def __call__(
517516
Examples:
518517
519518
Returns:
520-
[`~pipelines.cogview3.pipeline_CogView4.CogView3PipelineOutput`] or `tuple`:
521-
[`~pipelines.cogview3.pipeline_CogView4.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
519+
[`~pipelines.cogview4.pipeline_cogview4.CogView4PipelineOutput`] or `tuple`:
520+
[`~pipelines.cogview4.pipeline_cogview4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
522521
`tuple`. When returning a tuple, the first element is a list with the generated images.
523522
"""
524523

@@ -640,15 +639,13 @@ def __call__(
640639
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
641640

642641
# compute the previous noisy sample x_t -> x_t-1
643-
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
642+
if not isinstance(self.scheduler, CogView4DDIMScheduler):
644643
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
645644
else:
646645
latents, old_pred_original_sample = self.scheduler.step(
647-
noise_pred,
648-
old_pred_original_sample,
649-
t,
650-
timesteps[i - 1] if i > 0 else None,
651-
latents,
646+
model_output=noise_pred,
647+
timestep=t,
648+
sample=latents,
652649
**extra_step_kwargs,
653650
return_dict=False,
654651
)

src/diffusers/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
4545
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
4646
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
47+
_import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
4748
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
4849
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
4950
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -144,6 +145,7 @@
144145
from .scheduling_consistency_models import CMStochasticIterativeScheduler
145146
from .scheduling_ddim import DDIMScheduler
146147
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
148+
from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
147149
from .scheduling_ddim_inverse import DDIMInverseScheduler
148150
from .scheduling_ddim_parallel import DDIMParallelScheduler
149151
from .scheduling_ddpm import DDPMScheduler

0 commit comments

Comments
 (0)