diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
index b17e30ce2717..1503aacc05dc 100644
--- a/src/diffusers/models/transformers/transformer_skyreels_v2.py
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -64,8 +64,25 @@ def __init__(
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.linear.weight.data[:embedding_dim, :] = torch.eye(embedding_dim)
         self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim)
+        # Initialize bias to match the original scale_shift_table initialization pattern
+        self.linear.bias.data = torch.randn(output_dim) / embedding_dim**0.5
         self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)
 
+    def load_from_scale_shift_table(self, scale_shift_table: torch.Tensor) -> None:
+        """
+        Helper method to transfer scale_shift_table values from old model checkpoints.
+        This can be used to migrate models saved with the old format to the new AdaLayerNorm format.
+
+        Args:
+            scale_shift_table: Tensor of shape (1, 2, embedding_dim) from old model format
+        """
+        if scale_shift_table.shape[0] != 1 or scale_shift_table.shape[1] != 2:
+            raise ValueError(f"Expected scale_shift_table shape (1, 2, embedding_dim), got {scale_shift_table.shape}")
+
+        with torch.no_grad():
+            # Flatten the scale_shift_table to match the bias shape
+            self.linear.bias.data = scale_shift_table.view(-1)
+
     def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
         if temb.ndim == 2:
             # If temb is 2D, we assume it has 1-D time embedding values for each batch.
@@ -443,7 +460,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"]
-    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3", "norm_out"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"]
     _keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"]
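
Because the linear layer's weights are initialized as two stacked identity matrices, `linear(temb)` reduces to `cat([temb, temb]) + bias`, so copying the flattened table into the bias should reproduce the old `scale_shift_table + temb` shift/scale computation. A minimal migration sketch under that assumption follows; the checkpoint paths are hypothetical, and `model.norm_out` as the attribute holding this `AdaLayerNorm` is inferred from the `norm_out.linear.*` entries in `_keys_to_ignore_on_load_missing`:

```python
# Migration sketch for old-format checkpoints (paths below are hypothetical).
# Assumes the old state dict still contains the top-level "scale_shift_table"
# tensor of shape (1, 2, embedding_dim) that the new model ignores on load.
from safetensors.torch import load_file

from diffusers import SkyReelsV2Transformer3DModel

# Loading works out of the box: "scale_shift_table" is dropped as unexpected,
# and the freshly initialized norm_out.linear parameters are allowed to be missing.
model = SkyReelsV2Transformer3DModel.from_pretrained("path/to/old_checkpoint")

# Pull the old table out of the raw state dict and fold it into the new
# AdaLayerNorm bias via the helper added above.
old_state = load_file("path/to/old_checkpoint/diffusion_pytorch_model.safetensors")
model.norm_out.load_from_scale_shift_table(old_state["scale_shift_table"])
```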