@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     return checkpoint
+
+
+def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
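+    """Remap a Chroma transformer checkpoint from its original (Flux-style) key layout to the diffusers layout."""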
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
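+    # infer the number of double, single, and guidance blocks from the block indices present in the checkpoint keys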
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    num_guidance_layers = (
+        list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
+    )
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into
+    # (shift, scale), while diffusers splits it into (scale, shift). Swap the halves of the projection weights
+    # so the diffusers implementation can be used.
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    # guidance
+    converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
+        "distilled_guidance_layer.in_proj.bias"
+    )
+    converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
+        "distilled_guidance_layer.in_proj.weight"
+    )
+    converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
+        "distilled_guidance_layer.out_proj.bias"
+    )
+    converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
+        "distilled_guidance_layer.out_proj.weight"
+    )
+    for i in range(num_guidance_layers):
+        block_prefix = f"distilled_guidance_layer.layers.{i}."
+        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
+        )
+        converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.norms.{i}.scale"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # Q, K, V
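+        # the original checkpoint stores fused QKV projections for the image and text streams;
+        # split each into separate q/k/v tensors for diffusers' to_q/to_k/to_v and add_*_proj layers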
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # Q, K, V, mlp
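+        # single blocks fuse the Q/K/V projections and the MLP up-projection into one `linear1`;
+        # split it back into (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) chunks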
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    return converted_state_dict