
[RFC]: Fix + Standardize DiT Config/Init Patterns #1456

@alex-jw-brooks

Description


Motivation.

This will fix existing bad behavior in DiT initialization around default values copied from Diffusers, and hopefully adds a sustainable pattern for well-typed DiT configs. It will also handle merging parsed config values with default values in a way that keeps the model code as clear as possible.

Proposed Change.

This is mostly an internal change to prevent bugs in config parsing / DiT initialization and make the code clearer, so there will be no changes to accuracy, performance, or APIs.

Summary of Current Problem

There are currently some issues with how configs are handled, which partially stem from the way Diffusers manages its configs compared to Transformers (i.e., through mapping to classes via mixins rather than having easily inspectable PretrainedConfig classes per architecture).

In the current pipeline for diffusion models, most components are created via .from_pretrained, but the transformer component is not initialized in a consistent way. There are generally 3 ways that we initialize it:

  1. By passing the omni diffusion config and pulling values off the DiT config inside the class
  2. By exploding the DiT into kwargs derived from the DiT config
  3. By relying on defaults in the __init__ of the transformer class

The last of these is particularly problematic, because it causes classes to be initialized incorrectly. The defaults are copied from the corresponding Diffusers class, but since they are normally overridden by the subconfigs they're loaded from, the defaults may not actually be correct for a given checkpoint.

Example - LongCat

The current LongCat config management illustrates this issue well:

  • In the current Image pipeline, we pass the Diffusion config here
self.transformer = LongCatImageTransformer2DModel(od_config=od_config)

while the initializer takes many of these values here:

    def __init__(
        self,
        od_config: OmniDiffusionConfig,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        pooled_projection_dim: int = 3584,
        axes_dims_rope: list[int] = [16, 56, 56],
    ):

whose kwargs are used to set the attributes on the class, rather than the passed transformer config. These values are understandably taken from the defaults of the corresponding Diffusers class here.

However, AFAIK they are not the right values for any LongCat model; I think they were copied from Flux as the reference architecture. E.g., the transformer config for LongCatImage:

{
    "_class_name": "LongCatImageTransformer2DModel",
    "_diffusers_version": "0.30.0.dev0",
    "attention_head_dim": 128,
    "guidance_embeds": false,
    "in_channels": 64,
    "joint_attention_dim": 3584,
    "num_attention_heads": 24,
    "num_layers": 10,
    "num_single_layers": 20,
    "patch_size": 1,
    "pooled_projection_dim": 3584
}

As such, the pipeline ends up incorrectly configured, and in this case much heavier than it should be, because it initializes extra transformer blocks that it doesn't have weights for. I noticed this while working on TeaCache support for LongCat, since I was seeing NaNs and unexpected OOMs while calibrating coefficients through the transformer blocks. This is the approach currently used for diffusion pipelines.
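The default-masking problem above can be reproduced in miniature. This is a hypothetical sketch, not the actual vllm_omni classes; the numbers are the Flux-derived defaults vs. the LongCat checkpoint config shown above:

```python
# Toy stand-in for a DiT whose __init__ defaults were copied from Diffusers.
class ToyTransformer2DModel:
    def __init__(self, num_layers: int = 19, num_single_layers: int = 38):
        # Defaults copied from the reference architecture (Flux), not LongCat
        self.num_layers = num_layers
        self.num_single_layers = num_single_layers


# Values from the checkpoint's transformer config.json
checkpoint_cfg = {"num_layers": 10, "num_single_layers": 20}

# Pattern 3 (relying on __init__ defaults) silently ignores the checkpoint:
model = ToyTransformer2DModel()
assert model.num_layers == 19  # wrong: the checkpoint has 10 layers

# Pattern 2 (exploding the parsed config into kwargs) gets it right:
model = ToyTransformer2DModel(**checkpoint_cfg)
assert model.num_layers == 10
```

The extra 9 double blocks and 18 single blocks initialized by the defaults are exactly the "much heavier than it should be" behavior described above.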

Proposed Solution

Add a config directory, vllm_omni.diffusion.config, which implements a well-defined PretrainedConfig subclass per model that handles its default kwargs and can be used for type hints, similar to the custom Transformers configs in vLLM here.

More concretely, a small base class could be added for converting from the generic transformers config wrapper:

class BaseDiTConfig(ABC, PretrainedConfig):
    _class_name: str | None = None

    @classmethod
    def from_tf_config(cls, cfg: TransformerConfig) -> "BaseDiTConfig":
        """For now - converts the generic wrapper around Transformers configs
        that is currently used for DiT configs into an instance of this subclass.
        """
        model_dict = cfg.to_dict()
        # Validate against the expected Diffusers _class_name as needed
        parsed_name = model_dict.get("_class_name")
        if cls._class_name is not None and parsed_name != cls._class_name:
            raise ValueError(
                f"Expected a config for {cls._class_name}, got {parsed_name}"
            )
        return cls.from_dict(model_dict)
with something like this per DiT class:

from vllm_omni.diffusion.config.base import BaseDiTConfig

class ZImageTransformer2DModelConfig(BaseDiTConfig):
    # Expected _class_name in Diffusers
    _class_name = "ZImageTransformer2DModel"

    def __init__(
        self,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        ...
        axes_dims=[32, 48, 48],
        axes_lens=[1024, 512, 512],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        ...
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens

And in the corresponding pipeline:

        hf_config = ZImageTransformer2DModelConfig.from_tf_config(od_config.tf_model_config)
        self.transformer = ZImageTransformer2DModel(
            hf_config=hf_config,
            od_config=od_config,  # Should also be passed for things like quant config, etc.
        )

I think that passing the pretrained config subclass directly, as opposed to building it from the omni diffusion config inside the DiT, is probably a bit cleaner when considering some model architectures, e.g., wan2, which may need to build multiple transformers in its pipeline:

        self.transformer = get_dit_model(
            od_config=od_config,
            subfolder="transformer",
            local_files_only=local_files_only,
        ) if load_transformer else None

        self.transformer_2 = get_dit_model(
            od_config=od_config,
            subfolder="transformer_2",
            local_files_only=local_files_only,
        ) if load_transformer_2 else None
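Under the proposed pattern, each subfolder's parsed config would go through its typed config class before model construction. A toy sketch (all names here, including the Wan-style class and the per-subfolder config dicts, are illustrative assumptions rather than the real wan2 API):

```python
# Toy typed config for a multi-transformer pipeline; each subfolder config
# is validated and converted independently before its model is built.
class ToyWanDiTConfig:
    _class_name = "WanTransformer3DModel"

    def __init__(self, num_layers: int = 30):
        self.num_layers = num_layers

    @classmethod
    def from_tf_config(cls, model_dict: dict) -> "ToyWanDiTConfig":
        if model_dict.get("_class_name") != cls._class_name:
            raise ValueError(f"Wrong config: {model_dict.get('_class_name')}")
        return cls(num_layers=model_dict["num_layers"])


# Per-subfolder configs (transformer/ and transformer_2/ in the checkpoint)
subfolder_cfgs = {
    "transformer": {"_class_name": "WanTransformer3DModel", "num_layers": 30},
    "transformer_2": {"_class_name": "WanTransformer3DModel", "num_layers": 30},
}

# Each transformer gets its own validated, typed config instance
hf_configs = {
    name: ToyWanDiTConfig.from_tf_config(cfg)
    for name, cfg in subfolder_cfgs.items()
}
```

Building the typed configs in the pipeline rather than inside the DiT keeps the two-transformer case symmetric: the same model class can be constructed twice with two independently validated configs.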

I think this would also be nice for potential follow-ups, e.g., pushing the config conversion a bit earlier in the process to avoid the current generic Transformers wrapper and making the pipeline download calls more well-patterned in the future.

Feedback Period.

One week - will also open a draft PR for discussion in the next couple of days!

CC List.

@DarkLight1337 @Isotr0py @ywang96 @sfeng33 @NickLucche @hsliuustc0106 @Gaohan123 @tzhouam @ZJY0516 in case anyone has thoughts

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
