Description
Motivation.
This will fix existing bad behavior in DiT initialization around default values copied from Diffusers, and hopefully add a sustainable pattern for well-typed DiT configs. It will also handle merging parsed config values with default values in a way that keeps the model code as clear as possible.
Proposed Change.
This is mostly an internal change for preventing bugs in config parsing / DiT initialization and making the code more clear, so there will be no changes to accuracy, performance, or APIs.
Summary of Current Problem
There are currently some issues with how configs are being handled, which partially stem from the way Diffusers manages its configs compared to Transformers (i.e., through mapping to classes via mixins, rather than having easily inspectable PretrainedConfig classes per architecture).
In the current pipeline for diffusion models, most components are created via `.from_pretrained`, but the transformer component is not initialized consistently. There are generally three ways that we initialize it:
- By passing the omni diffusion config and pulling values off the DiT config inside the class
- By exploding the DiT into kwargs derived from the DiT config
- By relying on defaults in the `__init__` of the transformer class
The last of these is particularly problematic because it causes classes to be initialized incorrectly. The defaults are copied from the corresponding Diffusers class, but since these tend to be overridden by the subconfigs they're loaded from, they may not actually be correct.
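A minimal sketch of this failure mode (the class and config names here are illustrative, not from the actual codebase): when the pipeline only passes the top-level config, the kwarg defaults silently win over the checkpoint's values.

```python
# Illustrative sketch: kwarg defaults shadowing the checkpoint config.
# DemoTransformer and checkpoint_config are hypothetical names.

class DemoTransformer:
    def __init__(self, od_config, num_layers: int = 19, num_single_layers: int = 38):
        # Defaults copied from a reference architecture; a checkpoint with
        # different values is never consulted unless the caller forwards them.
        self.num_layers = num_layers
        self.num_single_layers = num_single_layers

# The checkpoint's transformer config actually says 10 / 20 ...
checkpoint_config = {"num_layers": 10, "num_single_layers": 20}

# ... but the pipeline only passes od_config, so the defaults win:
model = DemoTransformer(od_config=None)
print(model.num_layers)  # 19, not the 10 the weights were trained for
```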
Example - LongCat
The current LongCat config management illustrates this issue well:
- In the current image pipeline, we pass the diffusion config here:

```python
self.transformer = LongCatImageTransformer2DModel(od_config=od_config)
```

while the initializer takes many of these values here:
```python
def __init__(
    self,
    od_config: OmniDiffusionConfig,
    patch_size: int = 1,
    in_channels: int = 64,
    num_layers: int = 19,
    num_single_layers: int = 38,
    attention_head_dim: int = 128,
    num_attention_heads: int = 24,
    joint_attention_dim: int = 3584,
    pooled_projection_dim: int = 3584,
    axes_dims_rope: list[int] = [16, 56, 56],
):
```
whose kwargs are used to set the attributes on the class, not the passed transformer config. These values are understandably taken from the defaults of the corresponding Diffusers class here.
However, AFAIK they are not the right values for any LongCat model; I think they are copied from Flux as the reference architecture. For example, the transformer config for LongCatImage:
```json
{
  "_class_name": "LongCatImageTransformer2DModel",
  "_diffusers_version": "0.30.0.dev0",
  "attention_head_dim": 128,
  "guidance_embeds": false,
  "in_channels": 64,
  "joint_attention_dim": 3584,
  "num_attention_heads": 24,
  "num_layers": 10,
  "num_single_layers": 20,
  "patch_size": 1,
  "pooled_projection_dim": 3584
}
```

As such, the pipeline ends up incorrectly configured, and in this case much heavier than it should be, because it initializes extra transformer blocks that it doesn't have weights for. I noticed this while working on TeaCache support for LongCat, since I was seeing NaNs and unexpected OOMs while calibrating coefficients through the transformer blocks. This is the approach currently used for diffusion pipelines.
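To make the cost concrete, a quick back-of-the-envelope comparison of the layer counts (Flux-style `__init__` defaults vs. the LongCatImage config above):

```python
# Layer counts from the __init__ defaults vs. the checkpoint config shown above.
default_layers = {"num_layers": 19, "num_single_layers": 38}  # Flux-style defaults
actual_layers = {"num_layers": 10, "num_single_layers": 20}   # LongCatImage config

# Blocks that get allocated but have no weights to load:
extra = {k: default_layers[k] - actual_layers[k] for k in default_layers}
print(extra)  # {'num_layers': 9, 'num_single_layers': 18}
```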
Proposed Solution
Add a configs directory, `vllm_omni.diffusion.config`, which implements a well-defined `PretrainedConfig` subclass per model that handles its default kwargs and can be used for type hints, similar to the custom Transformers configs in vLLM here.
More concretely, a small base class could be added for converting from the generic transformers config wrapper:
```python
class BaseDiTConfig(ABC, PretrainedConfig):
    _class_name: str | None = None

    @classmethod
    def from_tf_config(cls, cfg: TransformerConfig):
        """For now - converts the generic wrapper around Transformers configs
        that is currently used for DiT configs to an instance of this subclass.
        """
        model_dict = cfg.to_dict()
        # Validate against the expected _class_name as needed, e.g.:
        if cls._class_name is not None and model_dict.get("_class_name") != cls._class_name:
            raise ValueError(
                f"Expected a config for {cls._class_name!r}, "
                f"got {model_dict.get('_class_name')!r}"
            )
        return cls.from_dict(model_dict)
```

with something like this per DiT class:
```python
from vllm_omni.diffusion.config.base import BaseDiTConfig


class ZImageTransformer2DModelConfig(BaseDiTConfig):
    # Expected _class_name in Diffusers
    _class_name = "ZImageTransformer2DModel"

    def __init__(
        self,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        ...
        axes_dims=[32, 48, 48],
        axes_lens=[1024, 512, 512],
        **kwargs,
    ):
        super().__init__()
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        ...
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
```

And in the corresponding pipeline:
```python
hf_config = ZImageTransformer2DModelConfig.from_tf_config(od_config.tf_model_config)
self.transformer = ZImageTransformer2DModel(
    hf_config=hf_config,
    od_config=od_config,  # Should also be passed for things like quant config etc.
)
```

I think that passing the pretrained config subclass directly, as opposed to building it from the omni diffusion config inside the DiT, is probably a bit cleaner when considering some of the model architectures, e.g., wan2, which may need to build multiple transformers in its pipeline:
```python
self.transformer = get_dit_model(
    od_config=od_config,
    subfolder="transformer",
    local_files_only=local_files_only,
) if load_transformer else None
self.transformer_2 = get_dit_model(
    od_config=od_config,
    subfolder="transformer_2",
    local_files_only=local_files_only,
) if load_transformer_2 else None
```

I think this would also be nice for potential follow-ups, e.g., we could push the config conversions a bit earlier in the process to avoid the current generic Transformers wrapper and make the pipeline download calls more well-patterned in the future.
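As a rough illustration of that follow-up (all names here are hypothetical, nothing below exists in the codebase today), the `_class_name` key in the Diffusers config could drive a small registry, so the conversion to a typed config can happen right after download:

```python
# Hypothetical sketch of a _class_name -> config-class registry.
DIT_CONFIG_REGISTRY: dict[str, type] = {}

def register_dit_config(cls: type) -> type:
    """Register a DiT config class under its Diffusers _class_name."""
    DIT_CONFIG_REGISTRY[cls._class_name] = cls
    return cls

@register_dit_config
class ZImageConfigStub:
    _class_name = "ZImageTransformer2DModel"

def resolve_config_class(model_dict: dict) -> type:
    """Look up the config class for a raw Diffusers config dict."""
    return DIT_CONFIG_REGISTRY[model_dict["_class_name"]]

cls = resolve_config_class({"_class_name": "ZImageTransformer2DModel"})
print(cls.__name__)  # ZImageConfigStub
```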
Feedback Period.
One week - will also open a draft for discussion in the coming couple of days!
CC List.
@DarkLight1337 @Isotr0py @ywang96 @sfeng33 @NickLucche @hsliuustc0106 @Gaohan123 @tzhouam @ZJY0516 in case anyone has thoughts
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.