Commit c0b2e5a
[HF] Model Definition Conversion Support for FLUX (#1582)
This PR adds the `FluxStateDictAdapter`, allowing us to convert checkpoints to and from HF.

Additional changes:
- Modifies the `download_hf_assets` script to support downloading diffusion-type safetensors files
- Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf` so that the conversion scripts can be reused, e.g.
  `python ./scripts/checkpoint_conversion/convert_from_hf.py ./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux --model_flavor flux-dev`

Tests: performing a KL divergence test on the forward pass of converted weights loaded in `torchtitan` against HF weights loaded with the HF `FluxTransformer2DModel`, we get:

```
Average loss for test from_hf is 7.233546986222528e-13
```

Additionally, we can now run inference with HF weights to verify the changes made in #1548.

### Batched Inference on TorchTitan

| | prompt0 | prompt1 | prompt2 |
| --- | --- | --- | --- |
| no CFG | <img width="1024" height="1024" alt="prompt0_nocfg" src="https://github.com/user-attachments/assets/421fab49-239a-4ca2-b51a-16823d89acfd" /> | <img width="1024" height="1024" alt="prompt1_nocfg" src="https://github.com/user-attachments/assets/534b557e-7b93-4f2e-b3b3-3a0c7cf57c40" /> | <img width="1024" height="1024" alt="prompt2_nocfg" src="https://github.com/user-attachments/assets/d0f33526-f95d-47db-b5a6-6200bfa151f9" /> |
| CFG | <img width="1024" height="1024" alt="prompt0_cfg" src="https://github.com/user-attachments/assets/83234675-eb47-4785-abe1-0f07dd854f1c" /> | <img width="1024" height="1024" alt="prompt1_cfg" src="https://github.com/user-attachments/assets/5e76f3e7-0ca3-47a4-a0ef-3c7e983e8c2c" /> | <img width="1024" height="1024" alt="prompt2_cfg" src="https://github.com/user-attachments/assets/c8cbe367-d96e-4559-a201-48e8dc3d18ee" /> |
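
For context, the forward-pass comparison can be written in a few lines. The following is a minimal sketch, not the actual test harness from this PR; `model_a`/`model_b` stand for a torchtitan `FluxModel` loaded from the converted checkpoint and an HF `FluxTransformer2DModel`, and `inputs` is a placeholder for identical forward-pass arguments:

```python
import torch
import torch.nn.functional as F


@torch.inference_mode()
def forward_kl(model_a, model_b, inputs):
    """KL divergence between the softmaxed forward outputs of two models.

    Assumes both models return raw prediction tensors; with correctly
    converted weights this should be ~0 (the PR reports ~7.2e-13).
    """
    out_a = model_a(**inputs).flatten().float()
    out_b = model_b(**inputs).flatten().float()
    # Work in log-space for numerical stability.
    log_p = F.log_softmax(out_a, dim=-1)
    log_q = F.log_softmax(out_b, dim=-1)
    # KL(p || q): F.kl_div takes the approximating distribution first.
    return F.kl_div(log_q, log_p, reduction="sum", log_target=True)
```
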
1 parent 9874e84 commit c0b2e5a

File tree: 7 files changed (+305, -4 lines changed)


scripts/checkpoint_conversion/convert_from_hf.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@
 
 @torch.inference_mode()
 def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
+    if model_name == "flux":
+        import torchtitan.experiments.flux  # noqa: F401
     # initialize model to allocate memory for state dict
     train_spec = train_spec_module.get_train_spec(model_name)
     model_args = train_spec.model_args[model_flavor]
```
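
The bare `import torchtitan.experiments.flux` works purely for its side effect: importing the package executes its `__init__.py`, which calls `register_train_spec` (see the `flux/__init__.py` diff below), so `get_train_spec("flux")` can succeed. A minimal sketch of this registration-by-import pattern, with simplified stand-in definitions rather than torchtitan's actual internals:

```python
from dataclasses import dataclass

# Toy registry; the real TrainSpec also carries model args, builder
# functions, and (after this PR) the state dict adapter.
_train_specs: dict[str, "TrainSpec"] = {}


@dataclass
class TrainSpec:
    name: str


def register_train_spec(spec: TrainSpec) -> None:
    _train_specs[spec.name] = spec


def get_train_spec(name: str) -> TrainSpec:
    # Raises KeyError unless the module that registers `name` was imported
    # first -- which is why the conversion scripts import the flux package.
    return _train_specs[name]
```
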

scripts/checkpoint_conversion/convert_to_hf.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@
 
 @torch.inference_mode()
 def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path):
+    if model_name == "flux":
+        import torchtitan.experiments.flux  # noqa: F401
     # load model and model args so that we can get the state dict shape
     train_spec = train_spec_module.get_train_spec(model_name)
     model_args = train_spec.model_args[model_flavor]
```

scripts/download_hf_assets.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -76,8 +76,8 @@ def download_hf_assets(
             "merges.txt",
             "special_tokens_map.json",
         ],
-        "safetensors": ["*.safetensors", "model.safetensors.index.json"],
-        "index": ["model.safetensors.index.json"],
+        "safetensors": ["*.safetensors", "*model.safetensors.index.json"],
+        "index": ["*model.safetensors.index.json"],
         "config": ["config.json", "generation_config.json"],
     }
```

torchtitan/experiments/flux/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 from .model.args import FluxModelArgs
 from .model.autoencoder import AutoEncoderParams
 from .model.model import FluxModel
+from .model.state_dict_adapter import FluxStateDictAdapter
 from .validate import build_flux_validator
 
 __all__ = [
@@ -119,5 +120,6 @@
         build_tokenizer_fn=None,
         build_loss_fn=build_mse_loss,
         build_validator_fn=build_flux_validator,
+        state_dict_adapter=FluxStateDictAdapter,
     )
 )
```

torchtitan/experiments/flux/model/state_dict_adapter.py

Lines changed: 288 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import re

from collections import defaultdict
from typing import Any

import torch

from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter

from .args import FluxModelArgs

logger = logging.getLogger()


class FluxStateDictAdapter(BaseStateDictAdapter):
    """
    State dict adapter for the Flux model, converting between the HuggingFace
    safetensors format and the torchtitan DCP format.

    This adapter handles only the state dict of the transformer from the Flux HF model repo.
    """

    def __init__(self, model_args: FluxModelArgs, hf_assets_path: str | None):

        # Build fqn-to-index mapping if hf_assets_path is provided
        if hf_assets_path:
            # If the directory is multimodal, ensure that hf_assets_path points to
            # the folder containing the transformer's safetensors
            if os.path.exists(os.path.join(hf_assets_path, "model_index.json")):
                hf_assets_path = os.path.join(hf_assets_path, "transformers")

            # Check if a safetensors index file exists
            index_files = [
                "model.safetensors.index.json",
                "diffusion_pytorch_model.safetensors.index.json",
            ]

            hf_safetensors_indx = None
            for index_file in index_files:
                mapping_path = os.path.join(hf_assets_path, index_file)
                if os.path.exists(mapping_path):
                    with open(mapping_path, "r") as f:
                        hf_safetensors_indx = json.load(f)
                    break
            if hf_safetensors_indx is None:
                logger.warning(
                    f"no safetensors index file found at hf_assets_path: {hf_assets_path}. \
                    Defaulting to saving a single safetensors file if checkpoint is saved in HF format.",
                )

            if hf_safetensors_indx:
                self.fqn_to_index_mapping = {}
                for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items():
                    indx = re.search(r"\d+", raw_indx).group(0)
                    self.fqn_to_index_mapping[hf_key] = indx
            else:
                self.fqn_to_index_mapping = None

        self.model_args = model_args
        self.hf_assets_path = hf_assets_path

        # mapping containing direct 1-to-1 mappings from HF to torchtitan
        self.from_hf_map_direct = {
            "x_embedder.bias": "img_in.bias",
            "x_embedder.weight": "img_in.weight",
            "context_embedder.bias": "txt_in.bias",
            "context_embedder.weight": "txt_in.weight",
            "norm_out.linear.bias": "final_layer.adaLN_modulation.1.bias",
            "norm_out.linear.weight": "final_layer.adaLN_modulation.1.weight",
            "proj_out.bias": "final_layer.linear.bias",
            "proj_out.weight": "final_layer.linear.weight",
            "time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
            "time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
            "time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
            "time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
            "time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
            "time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
            "time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
            "time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
            "single_transformer_blocks.{}.attn.norm_k.weight": "single_blocks.{}.norm.key_norm.weight",
            "single_transformer_blocks.{}.attn.norm_q.weight": "single_blocks.{}.norm.query_norm.weight",
            "single_transformer_blocks.{}.norm.linear.bias": "single_blocks.{}.modulation.lin.bias",
            "single_transformer_blocks.{}.norm.linear.weight": "single_blocks.{}.modulation.lin.weight",
            "single_transformer_blocks.{}.proj_out.bias": "single_blocks.{}.linear2.bias",
            "single_transformer_blocks.{}.proj_out.weight": "single_blocks.{}.linear2.weight",
            "transformer_blocks.{}.attn.norm_added_k.weight": "double_blocks.{}.txt_attn.norm.key_norm.weight",
            "transformer_blocks.{}.attn.norm_added_q.weight": "double_blocks.{}.txt_attn.norm.query_norm.weight",
            "transformer_blocks.{}.attn.norm_k.weight": "double_blocks.{}.img_attn.norm.key_norm.weight",
            "transformer_blocks.{}.attn.norm_q.weight": "double_blocks.{}.img_attn.norm.query_norm.weight",
            "transformer_blocks.{}.attn.to_add_out.bias": "double_blocks.{}.txt_attn.proj.bias",
            "transformer_blocks.{}.attn.to_add_out.weight": "double_blocks.{}.txt_attn.proj.weight",
            "transformer_blocks.{}.attn.to_out.0.bias": "double_blocks.{}.img_attn.proj.bias",
            "transformer_blocks.{}.attn.to_out.0.weight": "double_blocks.{}.img_attn.proj.weight",
            "transformer_blocks.{}.ff.net.0.proj.bias": "double_blocks.{}.img_mlp.0.bias",
            "transformer_blocks.{}.ff.net.0.proj.weight": "double_blocks.{}.img_mlp.0.weight",
            "transformer_blocks.{}.ff.net.2.bias": "double_blocks.{}.img_mlp.2.bias",
            "transformer_blocks.{}.ff.net.2.weight": "double_blocks.{}.img_mlp.2.weight",
            "transformer_blocks.{}.ff_context.net.0.proj.bias": "double_blocks.{}.txt_mlp.0.bias",
            "transformer_blocks.{}.ff_context.net.0.proj.weight": "double_blocks.{}.txt_mlp.0.weight",
            "transformer_blocks.{}.ff_context.net.2.bias": "double_blocks.{}.txt_mlp.2.bias",
            "transformer_blocks.{}.ff_context.net.2.weight": "double_blocks.{}.txt_mlp.2.weight",
            "transformer_blocks.{}.norm1.linear.bias": "double_blocks.{}.img_mod.lin.bias",
            "transformer_blocks.{}.norm1.linear.weight": "double_blocks.{}.img_mod.lin.weight",
            "transformer_blocks.{}.norm1_context.linear.bias": "double_blocks.{}.txt_mod.lin.bias",
            "transformer_blocks.{}.norm1_context.linear.weight": "double_blocks.{}.txt_mod.lin.weight",
        }

        # combination plan to keep track of the order of layers to be combined
        self.combination_plan = {
            "single_blocks.{}.linear1.bias": [
                "single_transformer_blocks.{}.attn.to_q.bias",
                "single_transformer_blocks.{}.attn.to_k.bias",
                "single_transformer_blocks.{}.attn.to_v.bias",
                "single_transformer_blocks.{}.proj_mlp.bias",
            ],
            "single_blocks.{}.linear1.weight": [
                "single_transformer_blocks.{}.attn.to_q.weight",
                "single_transformer_blocks.{}.attn.to_k.weight",
                "single_transformer_blocks.{}.attn.to_v.weight",
                "single_transformer_blocks.{}.proj_mlp.weight",
            ],
            "double_blocks.{}.txt_attn.qkv.bias": [
                "transformer_blocks.{}.attn.add_q_proj.bias",
                "transformer_blocks.{}.attn.add_k_proj.bias",
                "transformer_blocks.{}.attn.add_v_proj.bias",
            ],
            "double_blocks.{}.txt_attn.qkv.weight": [
                "transformer_blocks.{}.attn.add_q_proj.weight",
                "transformer_blocks.{}.attn.add_k_proj.weight",
                "transformer_blocks.{}.attn.add_v_proj.weight",
            ],
            "double_blocks.{}.img_attn.qkv.bias": [
                "transformer_blocks.{}.attn.to_q.bias",
                "transformer_blocks.{}.attn.to_k.bias",
                "transformer_blocks.{}.attn.to_v.bias",
            ],
            "double_blocks.{}.img_attn.qkv.weight": [
                "transformer_blocks.{}.attn.to_q.weight",
                "transformer_blocks.{}.attn.to_k.weight",
                "transformer_blocks.{}.attn.to_v.weight",
            ],
        }

        # reverse of combination plan: maps fqns to the fqn they are combined into
        self.reverse_combination_plan = {
            value: key
            for key, value_list in self.combination_plan.items()
            for value in value_list
        }

    # The original Flux implementation and HF swap shift and scale:
    # https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py#L63-L68
    def _swap_scale_shift(self, weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert torchtitan DCP state dict to HuggingFace safetensors format."""

        to_hf_map_direct = {
            v: k for k, v in self.from_hf_map_direct.items() if v is not None
        }
        hf_state_dict = {}

        for key, value in state_dict.items():
            # Extract layer_num and abstract key if necessary
            if "blocks" in key:
                layer_num = re.search(r"\d+", key).group(0)
                key = re.sub(r"(\d+)", "{}", key, count=1)
            else:
                layer_num = None

            if key in to_hf_map_direct:
                # handle direct mapping
                new_key = to_hf_map_direct[key]

                # perform swap to be compatible with HF
                if key in [
                    "final_layer.adaLN_modulation.1.weight",
                    "final_layer.adaLN_modulation.1.bias",
                ]:
                    value = self._swap_scale_shift(value)

                if new_key is None:
                    continue
                if layer_num:
                    new_key = new_key.format(layer_num)

                hf_state_dict[new_key] = value

            elif key in self.combination_plan:
                # handle splitting layers
                if key in [
                    "single_blocks.{}.linear1.bias",
                    "single_blocks.{}.linear1.weight",
                ]:
                    mlp_hidden_dim = int(
                        self.model_args.hidden_size * self.model_args.mlp_ratio
                    )
                    split_plan = [
                        self.model_args.hidden_size,
                        self.model_args.hidden_size,
                        self.model_args.hidden_size,
                        mlp_hidden_dim,
                    ]
                    # split into q, k, v, mlp
                    split_vals = torch.split(
                        value,
                        split_plan,
                        dim=0,
                    )
                else:
                    # split into q, k, v
                    split_vals = torch.split(value, self.model_args.hidden_size, dim=0)

                new_keys = (
                    abstract_key.format(layer_num)
                    for abstract_key in self.combination_plan[key]
                )

                for new_key, value in zip(new_keys, split_vals):
                    hf_state_dict[new_key] = value

        return hf_state_dict

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert HuggingFace safetensors state dict to torchtitan DCP format."""
        state_dict = {}

        # Keeps track of HF fqn values to combine into one TT fqn later:
        # {tt_fqn: {hf_fqn1: value, hf_fqn2: value, ...}}
        to_combine = defaultdict(dict)

        for key, value in hf_state_dict.items():
            # extract layer_num and abstract key if necessary
            if "blocks" in key:
                layer_num = re.search(r"\d+", key).group(0)
                key = re.sub(r"(\d+)", "{}", key, count=1)
            else:
                layer_num = None

            if key in self.from_hf_map_direct:
                new_key = self.from_hf_map_direct[key]

                # perform swap to be compatible with HF
                if key in [
                    "norm_out.linear.weight",
                    "norm_out.linear.bias",
                ]:
                    value = self._swap_scale_shift(value)
                if new_key is None:
                    continue
                if layer_num:
                    new_key = new_key.format(layer_num)

                state_dict[new_key] = value
            elif key in self.reverse_combination_plan:
                # collect the layers that need to be combined
                tt_abstract_key = self.reverse_combination_plan[key]
                if tt_abstract_key is None:
                    continue
                to_combine[tt_abstract_key.format(layer_num)][
                    key.format(layer_num)
                ] = value

        # combine collected values
        for tt_fqn, hf_fqn_map in to_combine.items():
            layer_num = re.search(r"\d+", tt_fqn).group(0)
            tt_abstract_key = re.sub(r"(\d+)", "{}", tt_fqn, count=1)
            combine_values = []
            # use combination_plan to ensure correct order before concatenation
            for hf_abstract_key in self.combination_plan[tt_abstract_key]:
                hf_key = hf_abstract_key.format(layer_num)
                combine_values.append(hf_fqn_map[hf_key])

            value = torch.cat(combine_values, dim=0)
            state_dict[tt_fqn] = value

        return state_dict
```
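
The ordering baked into `combination_plan` is what makes the fused and per-projection layouts line up. Below is a standalone round-trip check of the qkv fusion that `from_hf`/`to_hf` perform, using a toy hidden size (illustrative only; FLUX.1-dev's actual hidden size is 3072):

```python
import torch

hidden_size = 8  # toy value for illustration

# HF stores separate q/k/v projections; torchtitan fuses them into one qkv tensor.
to_q = torch.randn(hidden_size, hidden_size)
to_k = torch.randn(hidden_size, hidden_size)
to_v = torch.randn(hidden_size, hidden_size)

# from_hf direction: concatenate along dim 0 in the order fixed by
# combination_plan (q, k, v).
qkv = torch.cat([to_q, to_k, to_v], dim=0)

# to_hf direction: split back into hidden_size-sized chunks along dim 0.
q2, k2, v2 = torch.split(qkv, hidden_size, dim=0)
assert all(torch.equal(a, b) for a, b in [(q2, to_q), (k2, to_k), (v2, to_v)])
```

The `single_blocks.{}.linear1` case works the same way, except the last chunk is the MLP projection of size `hidden_size * mlp_ratio`, which is why `to_hf` builds an explicit `split_plan` there instead of using equal-sized splits.
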

torchtitan/experiments/flux/validate.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -104,6 +104,10 @@ def validate(
         model = model_parts[0]
         model.eval()
 
+        # Disable cfg dropout during validation
+        training_cfg_prob = self.job_config.training.classifier_free_guidance_prob
+        self.job_config.training.classifier_free_guidance_prob = 0.0
+
         save_img_count = self.job_config.validation.save_img_count
 
         parallel_dims = self.parallel_dims
@@ -244,6 +248,9 @@ def validate(
         # Set model back to train mode
         model.train()
 
+        # re-enable cfg dropout for training
+        self.job_config.training.classifier_free_guidance_prob = training_cfg_prob
+
 
 def build_flux_validator(
     job_config: JobConfig,
```
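
Because validation temporarily mutates shared config state, a slightly more defensive variant (a sketch of an alternative, not what the diff does) restores the probability in a `finally` block so it survives exceptions raised mid-validation:

```python
# Sketch: same save/restore idea as the diff, hardened with try/finally.
training_cfg_prob = self.job_config.training.classifier_free_guidance_prob
self.job_config.training.classifier_free_guidance_prob = 0.0
try:
    ...  # run the validation loop with CFG dropout disabled
finally:
    # restored even if validation raises
    self.job_config.training.classifier_free_guidance_prob = training_cfg_prob
```
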

torchtitan/protocols/state_dict_adapter.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -66,8 +66,8 @@ def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
                 hf_safetensors_indx = json.load(f)
         except FileNotFoundError:
             logger.warning(
-                "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \
-                Defaulting to saving a single safetensors file if checkpoint is saved in HF format.",
+                f"model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \
+                Defaulting to saving a single safetensors file if checkpoint is saved in HF format."
             )
             hf_safetensors_indx = None
```
