Commit 1c25438

feat: implement FLUX.1-Control model. (jd-opensource#319)

1 parent: ef0c5fe

File tree: 12 files changed (+486 −127 lines)

xllm/core/framework/batch/dit_batch.cpp

Lines changed: 6 additions & 1 deletion
```diff
@@ -62,7 +62,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
 
   std::vector<torch::Tensor> images;
   std::vector<torch::Tensor> mask_images;
-
+  std::vector<torch::Tensor> control_images;
   std::vector<torch::Tensor> latents;
   std::vector<torch::Tensor> masked_image_latents;
   for (const auto& request : request_vec_) {
@@ -96,6 +96,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
 
     images.emplace_back(input_params.image);
     mask_images.emplace_back(input_params.mask_image);
+    control_images.emplace_back(input_params.control_image);
   }
 
   if (input.prompts.size() != request_vec_.size()) {
@@ -122,6 +123,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
     input.mask_images = torch::stack(mask_images);
   }
 
+  if (check_tensors_valid(control_images)) {
+    input.control_image = torch::stack(control_images);
+  }
+
   if (check_tensors_valid(prompt_embeds)) {
    input.prompt_embeds = torch::stack(prompt_embeds);
   }
```
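The batching pattern above collects one optional tensor per request and only stacks when every request supplied one. `check_tensors_valid` is not part of this diff; a minimal sketch of the check-then-stack idiom under that assumption:

```cpp
#include <torch/torch.h>

#include <vector>

// Assumed semantics of check_tensors_valid (its body is outside this diff):
// a batch is stackable only if every per-request tensor is defined.
bool check_tensors_valid(const std::vector<torch::Tensor>& tensors) {
  if (tensors.empty()) return false;
  for (const auto& t : tensors) {
    if (!t.defined()) return false;  // this request carried no control image
  }
  return true;
}

int main() {
  std::vector<torch::Tensor> control_images = {torch::rand({3, 64, 64}),
                                               torch::rand({3, 64, 64})};
  torch::Tensor batched;
  if (check_tensors_valid(control_images)) {
    batched = torch::stack(control_images);  // shape [2, 3, 64, 64]
  }
  return 0;
}
```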

xllm/core/framework/request/dit_request_params.cpp

Lines changed: 10 additions & 0 deletions
```diff
@@ -270,6 +270,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
     }
   }
 
+  if (input.has_control_image()) {
+    std::string raw_bytes;
+    if (!butil::Base64Decode(input.control_image(), &raw_bytes)) {
+      LOG(ERROR) << "Base64 control_image decode failed";
+    }
+    if (!decoder.decode(raw_bytes, input_params.control_image)) {
+      LOG(ERROR) << "Control_image decode failed.";
+    }
+  }
+
   // generation params
   const auto& params = request.parameters();
   if (params.has_size()) {
```
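The request path above is a two-stage decode: base64 text from the proto field into raw image bytes, then raw bytes into a tensor. A self-contained sketch of that flow; `ImageDecoder` is a hypothetical stand-in, since the real type behind `decoder` is outside this diff:

```cpp
#include <butil/base64.h>  // brpc's base64 helpers
#include <glog/logging.h>
#include <torch/torch.h>

#include <string>

// Hypothetical stand-in for the `decoder` object used above.
struct ImageDecoder {
  bool decode(const std::string& bytes, torch::Tensor& out) {
    // A real implementation would parse PNG/JPEG bytes into a CHW tensor.
    out = torch::zeros({3, 1, 1});
    return !bytes.empty();
  }
};

// Stage 1: base64 text -> raw image bytes. Stage 2: bytes -> tensor.
bool decode_control_image(const std::string& base64_payload,
                          ImageDecoder& decoder,
                          torch::Tensor& out) {
  std::string raw_bytes;
  if (!butil::Base64Decode(base64_payload, &raw_bytes)) {
    LOG(ERROR) << "Base64 control_image decode failed";
    return false;
  }
  return decoder.decode(raw_bytes, out);
}
```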

xllm/core/framework/request/dit_request_state.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -92,6 +92,8 @@ struct DiTInputParams {
 
   torch::Tensor image;
 
+  torch::Tensor control_image;
+
   torch::Tensor mask_image;
 
   torch::Tensor masked_image_latent;
```

xllm/core/runtime/dit_forward_params.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -84,6 +84,8 @@ struct DiTForwardInput {
 
   torch::Tensor mask_images;
 
+  torch::Tensor control_image;
+
   torch::Tensor masked_image_latents;
 
   torch::Tensor prompt_embeds;
```

xllm/models/dit/autoencoder_kl.h

Lines changed: 6 additions & 8 deletions
```diff
@@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
                         bool do_normalize = true,
                         bool do_binarize = false,
                         bool do_convert_rgb = false,
-                        bool do_convert_grayscale = false) {
+                        bool do_convert_grayscale = false,
+                        int64_t latent_channels = 4) {
     const auto& model_args = context.get_model_args();
+    options_ = context.get_tensor_options();
     scale_factor_ = 1 << model_args.block_out_channels().size();
-    latent_channels_ = 4;
+    latent_channels_ = latent_channels;
     do_resize_ = do_resize;
     do_normalize_ = do_normalize;
     do_binarize_ = do_binarize;
@@ -116,7 +118,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (channel == latent_channels_) {
       return image;
     }
-
     auto [target_h, target_w] =
         get_default_height_width(processed, height, width);
     if (do_resize_) {
@@ -129,13 +130,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (do_binarize_) {
       processed = (processed >= 0.5f).to(torch::kFloat32);
     }
-    processed = processed.to(image.dtype());
+    processed = processed.to(options_);
     return processed;
   }
 
   torch::Tensor postprocess(
       const torch::Tensor& tensor,
-      const std::string& output_type = "pt",
       std::optional<std::vector<bool>> do_denormalize = std::nullopt) {
     torch::Tensor processed = tensor.clone();
     if (do_normalize_) {
@@ -149,9 +149,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
         }
       }
     }
-    if (output_type == "np") {
-      return processed.permute({0, 2, 3, 1}).contiguous();
-    }
     return processed;
   }
 
@@ -202,6 +199,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
   bool do_binarize_ = false;
   bool do_convert_rgb_ = false;
   bool do_convert_grayscale_ = false;
+  torch::TensorOptions options_;
 };
 TORCH_MODULE(VAEImageProcessor);
```
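A recurring refactor in this commit replaces separate device/dtype bookkeeping with one cached `torch::TensorOptions` (the new `options_` member): a single `.to(options)` converts placement and precision in one hop. A minimal sketch:

```cpp
#include <torch/torch.h>

int main() {
  // One TensorOptions value carries both device and dtype.
  torch::TensorOptions options =
      torch::TensorOptions().device(torch::kCPU).dtype(torch::kHalf);

  torch::Tensor x = torch::rand({2, 3});  // float32, CPU
  torch::Tensor y = x.to(options);        // half precision in a single call,
                                          // replacing .to(device).to(dtype)
  TORCH_CHECK(y.dtype() == torch::kHalf);
  return 0;
}
```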

xllm/models/dit/dit.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -1436,7 +1436,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module {
     proj_out_->verify_loaded_weights(prefix + "proj_out.");
   }
 
-  int64_t in_channels() { return out_channels_; }
+  int64_t in_channels() { return in_channels_; }
   bool guidance_embeds() { return guidance_embeds_; }
 
  private:
```
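Context for the one-line fix: the accessor returned `out_channels_` under the name `in_channels()`. The distinction matters for FLUX.1-Control style checkpoints, where control latents are concatenated with image latents along the channel dimension, so the transformer's input width exceeds its output width:

```cpp
// Illustrative numbers only (assumed from the public FLUX.1 Control
// checkpoints, not read from this repository):
//   in_channels_  = 128;  // packed image latents + packed control latents
//   out_channels_ = 64;   // predicted image latents only
// With the old accessor, in_channels() reported 64, so callers sizing the
// transformer input for control conditioning would under-allocate channels.
```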

xllm/models/dit/pipeline_flux.h

Lines changed: 41 additions & 56 deletions
```diff
@@ -30,43 +30,35 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     const auto& model_args = context.get_model_args("vae");
     options_ = context.get_tensor_options();
     vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
-    device_ = options_.device();
-    dtype_ = options_.dtype().toScalarType();
 
     vae_shift_factor_ = model_args.shift_factor();
     vae_scaling_factor_ = model_args.scale_factor();
-    default_sample_size_ = 128;
-    tokenizer_max_length_ = 77;  // TODO: get from config file
+    tokenizer_max_length_ =
+        context.get_model_args("text_encoder").max_position_embeddings();
     LOG(INFO) << "Initializing Flux pipeline...";
-    vae_image_processor_ = VAEImageProcessor(
-        context.get_model_context("vae"), true, true, false, false, false);
+    vae_image_processor_ = VAEImageProcessor(context.get_model_context("vae"),
+                                             true,
+                                             true,
+                                             false,
+                                             false,
+                                             false,
+                                             model_args.latent_channels());
     vae_ = VAE(context.get_model_context("vae"));
-    LOG(INFO) << "VAE initialized.";
     pos_embed_ = register_module(
         "pos_embed",
-        FluxPosEmbed(10000,
+        FluxPosEmbed(ROPE_SCALE_BASE,
                      context.get_model_args("transformer").axes_dims_rope()));
     transformer_ = FluxDiTModel(context.get_model_context("transformer"));
-    LOG(INFO) << "DiT transformer initialized.";
     t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
-    LOG(INFO) << "T5 initialized.";
     clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
-    LOG(INFO) << "CLIP text model initialized.";
     scheduler_ =
         FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
-    LOG(INFO) << "Flux pipeline initialized.";
     register_module("vae", vae_);
-    LOG(INFO) << "VAE registered.";
     register_module("vae_image_processor", vae_image_processor_);
-    LOG(INFO) << "VAE image processor registered.";
     register_module("transformer", transformer_);
-    LOG(INFO) << "DiT transformer registered.";
     register_module("t5", t5_);
-    LOG(INFO) << "T5 registered.";
     register_module("scheduler", scheduler_);
-    LOG(INFO) << "Scheduler registered.";
     register_module("clip_text_model", clip_text_model_);
-    LOG(INFO) << "CLIP text model registered.";
   }
 
   DiTForwardOutput forward(const DiTForwardInput& input) {
@@ -104,21 +96,21 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
             : std::nullopt;
 
     std::vector<torch::Tensor> output = forward_(
-        prompts,                                       // prompt
-        prompts_2,                                     // prompt_2
-        negative_prompts,                              // negative_prompt
-        negative_prompts_2,                            // negative_prompt_2
-        generation_params.true_cfg_scale,              // cfg scale
-        std::make_optional(generation_params.height),  // height
-        std::make_optional(generation_params.width),   // width
-        generation_params.num_inference_steps,         // num_inference_steps
-        generation_params.guidance_scale,              // guidance_scale
-        generation_params.num_images_per_prompt,       // num_images_per_prompt
-        seed,                                          // seed
-        latents,                                       // latents
-        prompt_embeds,                                 // prompt_embeds
-        negative_prompt_embeds,                        // negative_prompt_embeds
-        pooled_prompt_embeds,                          // pooled_prompt_embeds
+        prompts,                                  // prompt
+        prompts_2,                                // prompt_2
+        negative_prompts,                         // negative_prompt
+        negative_prompts_2,                       // negative_prompt_2
+        generation_params.true_cfg_scale,         // cfg scale
+        generation_params.height,                 // height
+        generation_params.width,                  // width
+        generation_params.num_inference_steps,    // num_inference_steps
+        generation_params.guidance_scale,         // guidance_scale
+        generation_params.num_images_per_prompt,  // num_images_per_prompt
+        seed,                                     // seed
+        latents,                                  // latents
+        prompt_embeds,                            // prompt_embeds
+        negative_prompt_embeds,                   // negative_prompt_embeds
+        pooled_prompt_embeds,                     // pooled_prompt_embeds
         negative_pooled_prompt_embeds,          // negative_pooled_prompt_embeds
         generation_params.max_sequence_length   // max_sequence_length
     );
@@ -141,13 +133,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     LOG(INFO)
         << "Flux model components loaded, start to load weights to sub models";
     transformer_->load_model(std::move(transformer_loader));
-    transformer_->to(device_);
+    transformer_->to(options_.device());
     vae_->load_model(std::move(vae_loader));
-    vae_->to(device_);
+    vae_->to(options_.device());
     t5_->load_model(std::move(t5_loader));
-    t5_->to(device_);
+    t5_->to(options_.device());
     clip_text_model_->load_model(std::move(clip_loader));
-    clip_text_model_->to(device_);
+    clip_text_model_->to(options_.device());
     tokenizer_ = tokenizer_loader->tokenizer();
     tokenizer_2_ = tokenizer_2_loader->tokenizer();
   }
@@ -186,8 +178,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
       std::optional<std::vector<std::string>> negative_prompt = std::nullopt,
       std::optional<std::vector<std::string>> negative_prompt_2 = std::nullopt,
       float true_cfg_scale = 1.0f,
-      std::optional<int64_t> height = std::nullopt,
-      std::optional<int64_t> width = std::nullopt,
+      int64_t height = 512,
+      int64_t width = 512,
       int64_t num_inference_steps = 28,
       float guidance_scale = 3.5f,
       int64_t num_images_per_prompt = 1,
@@ -199,12 +191,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
      std::optional<torch::Tensor> negative_pooled_prompt_embeds = std::nullopt,
      int64_t max_sequence_length = 512) {
     torch::NoGradGuard no_grad;
-    int64_t actual_height = height.has_value()
-                                ? height.value()
-                                : default_sample_size_ * vae_scale_factor_;
-    int64_t actual_width = width.has_value()
-                               ? width.value()
-                               : default_sample_size_ * vae_scale_factor_;
     int64_t batch_size;
     if (prompt.has_value()) {
       batch_size = prompt.value().size();
@@ -244,8 +230,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     auto [prepared_latents, latent_image_ids] =
         prepare_latents(total_batch_size,
                         num_channels_latents,
-                        actual_height,
-                        actual_width,
+                        height,
+                        width,
                         seed.has_value() ? seed.value() : 42,
                         latents);
     // prepare timestep
@@ -263,7 +249,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
                         scheduler_->base_shift(),
                         scheduler_->max_shift());
     auto [timesteps, num_inference_steps_actual] = retrieve_timesteps(
-        scheduler_, num_inference_steps, device_, new_sigmas, mu);
+        scheduler_, num_inference_steps, options_.device(), new_sigmas, mu);
     int64_t num_warmup_steps =
         std::max(static_cast<int64_t>(timesteps.numel()) -
                      num_inference_steps_actual * scheduler_->order(),
@@ -272,7 +258,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     torch::Tensor guidance;
     if (transformer_->guidance_embeds()) {
       torch::TensorOptions options =
-          torch::dtype(torch::kFloat32).device(device_);
+          torch::dtype(torch::kFloat32).device(options_.device());
 
       guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options);
       guidance = guidance.expand({prepared_latents.size(0)});
@@ -284,8 +270,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     auto [rot_emb1, rot_emb2] =
         pos_embed_->forward_cache(text_ids,
                                   latent_image_ids,
-                                  height.value() / (vae_scale_factor_ * 2),
-                                  width.value() / (vae_scale_factor_ * 2));
+                                  height / (vae_scale_factor_ * 2),
+                                  width / (vae_scale_factor_ * 2));
     torch::Tensor image_rotary_emb = torch::stack({rot_emb1, rot_emb2}, 0);
     for (int64_t i = 0; i < timesteps.numel(); ++i) {
       torch::Tensor t = timesteps[i].unsqueeze(0);
@@ -326,13 +312,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     }
     torch::Tensor image;
     // Unpack latents
-    torch::Tensor unpacked_latents = unpack_latents(
-        prepared_latents, actual_height, actual_width, vae_scale_factor_);
+    torch::Tensor unpacked_latents =
+        unpack_latents(prepared_latents, height, width, vae_scale_factor_);
     unpacked_latents =
         (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
-    unpacked_latents = unpacked_latents.to(dtype_);
+    unpacked_latents = unpacked_latents.to(options_.dtype());
     image = vae_->decode(unpacked_latents);
-    image = vae_image_processor_->postprocess(image, "pil");
+    image = vae_image_processor_->postprocess(image);
     return std::vector<torch::Tensor>{{image}};
   }
 
@@ -343,7 +329,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
   FluxDiTModel transformer_{nullptr};
   float vae_scaling_factor_;
   float vae_shift_factor_;
-  int default_sample_size_;
   FluxPosEmbed pos_embed_{nullptr};
 };
 TORCH_MODULE(FluxPipeline);
```
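The `height / (vae_scale_factor_ * 2)` divisions above encode how an image size maps onto Flux's packed latent grid: the VAE downsamples by `vae_scale_factor_`, then 2×2 latent patches are packed into single tokens. A small sketch of the arithmetic:

```cpp
#include <cstdint>
#include <iostream>

// Image dimension -> packed token-grid dimension, mirroring the divisions
// in forward_ above (vae_scale_factor_ is 8 for the standard Flux VAE).
int64_t packed_grid_dim(int64_t image_dim, int64_t vae_scale_factor) {
  return image_dim / (vae_scale_factor * 2);
}

int main() {
  const int64_t grid_h = packed_grid_dim(512, 8);  // 32
  const int64_t grid_w = packed_grid_dim(512, 8);  // 32
  std::cout << "image_seq_len = " << grid_h * grid_w << "\n";  // 1024 tokens
  return 0;
}
```

With `default_sample_size_` removed, callers that pass no size now get a plain 512×512 default instead of a value derived from the VAE scale factor.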

xllm/models/dit/pipeline_flux_base.h

Lines changed: 8 additions & 6 deletions
```diff
@@ -36,6 +36,8 @@ limitations under the License.
 
 namespace xllm {
 
+constexpr int64_t ROPE_SCALE_BASE = 10000;
+
 float calculate_shift(int64_t image_seq_len,
                       int64_t base_seq_len = 256,
                       int64_t max_seq_len = 4096,
@@ -213,9 +215,9 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
     auto input_ids =
         torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong))
             .view({batch_size, max_sequence_length})
-            .to(device_);
+            .to(options_.device());
     torch::Tensor prompt_embeds = t5_->forward(input_ids);
-    prompt_embeds = prompt_embeds.to(device_).to(dtype_);
+    prompt_embeds = prompt_embeds.to(options_);
     int64_t seq_len = prompt_embeds.size(1);
     prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt, 1});
     prompt_embeds =
@@ -244,10 +246,10 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
     auto input_ids =
         torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong))
             .view({batch_size, tokenizer_max_length_})
-            .to(device_);
+            .to(options_.device());
     auto encoder_output = clip_text_model_->forward(input_ids);
     torch::Tensor prompt_embeds = encoder_output;
-    prompt_embeds = prompt_embeds.to(device_).to(dtype_);
+    prompt_embeds = prompt_embeds.to(options_);
     prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt});
     prompt_embeds =
         prompt_embeds.view({batch_size * num_images_per_prompt, -1});
@@ -281,8 +283,8 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
       prompt_embeds = get_t5_prompt_embeds(
          prompt_2_list, num_images_per_prompt, max_sequence_length);
     }
-    torch::Tensor text_ids = torch::zeros({prompt_embeds.value().size(1), 3},
-                                          torch::device(device_).dtype(dtype_));
+    torch::Tensor text_ids =
+        torch::zeros({prompt_embeds.value().size(1), 3}, options_);
 
     return std::make_tuple(prompt_embeds.value(),
                            pooled_prompt_embeds.has_value()
```
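`calculate_shift`, declared above with its sequence-length defaults, feeds the `mu` consumed by `retrieve_timesteps` in pipeline_flux.h. Its body is outside this diff; the sketch below follows the linear interpolation used by diffusers' Flux pipeline, which these defaults appear to match (an assumption, not taken from this repository; `base_shift`/`max_shift` defaults are illustrative, since the pipeline reads them from the scheduler config):

```cpp
#include <cstdint>

// Linearly interpolate the flow-matching shift mu between base_shift
// (at base_seq_len tokens) and max_shift (at max_seq_len tokens).
float calculate_shift_sketch(int64_t image_seq_len,
                             int64_t base_seq_len = 256,
                             int64_t max_seq_len = 4096,
                             float base_shift = 0.5f,
                             float max_shift = 1.15f) {
  const float m = (max_shift - base_shift) /
                  static_cast<float>(max_seq_len - base_seq_len);
  const float b = base_shift - m * static_cast<float>(base_seq_len);
  return m * static_cast<float>(image_seq_len) + b;
}
// e.g. a 1024-token 512x512 image gives mu ≈ 0.63 with these defaults.
```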
