@@ -30,43 +30,35 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     const auto& model_args = context.get_model_args("vae");
     options_ = context.get_tensor_options();
     vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
-    device_ = options_.device();
-    dtype_ = options_.dtype().toScalarType();
 
     vae_shift_factor_ = model_args.shift_factor();
     vae_scaling_factor_ = model_args.scale_factor();
-    default_sample_size_ = 128;
-    tokenizer_max_length_ = 77;  // TODO: get from config file
+    tokenizer_max_length_ =
+        context.get_model_args("text_encoder").max_position_embeddings();
     LOG(INFO) << "Initializing Flux pipeline...";
-    vae_image_processor_ = VAEImageProcessor(
-        context.get_model_context("vae"), true, true, false, false, false);
+    vae_image_processor_ = VAEImageProcessor(context.get_model_context("vae"),
+                                             true,
+                                             true,
+                                             false,
+                                             false,
+                                             false,
+                                             model_args.latent_channels());
     vae_ = VAE(context.get_model_context("vae"));
-    LOG(INFO) << "VAE initialized.";
     pos_embed_ = register_module(
         "pos_embed",
-        FluxPosEmbed(10000,
+        FluxPosEmbed(ROPE_SCALE_BASE,
                      context.get_model_args("transformer").axes_dims_rope()));
     transformer_ = FluxDiTModel(context.get_model_context("transformer"));
-    LOG(INFO) << "DiT transformer initialized.";
     t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
-    LOG(INFO) << "T5 initialized.";
     clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
-    LOG(INFO) << "CLIP text model initialized.";
     scheduler_ =
         FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
-    LOG(INFO) << "Flux pipeline initialized.";
     register_module("vae", vae_);
-    LOG(INFO) << "VAE registered.";
     register_module("vae_image_processor", vae_image_processor_);
-    LOG(INFO) << "VAE image processor registered.";
     register_module("transformer", transformer_);
-    LOG(INFO) << "DiT transformer registered.";
     register_module("t5", t5_);
-    LOG(INFO) << "T5 registered.";
     register_module("scheduler", scheduler_);
-    LOG(INFO) << "Scheduler registered.";
     register_module("clip_text_model", clip_text_model_);
-    LOG(INFO) << "CLIP text model registered.";
   }
 
   DiTForwardOutput forward(const DiTForwardInput& input) {
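Note on the constructor changes: `tokenizer_max_length_` is now read from the text encoder config instead of the hardcoded 77; for the stock CLIP text encoder, `max_position_embeddings` is 77, so the value should be unchanged. The `vae_scale_factor_` line is the other piece of arithmetic worth spelling out; a minimal standalone sketch, assuming the stock Flux VAE's four-entry `block_out_channels`:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // The stock Flux VAE config lists four block_out_channels
  // (e.g. {128, 256, 512, 512}), giving a spatial downsampling
  // factor of 1 << (4 - 1) = 8: a 512x512 image maps to 64x64 latents.
  std::vector<int64_t> block_out_channels = {128, 256, 512, 512};
  int64_t vae_scale_factor =
      1LL << (static_cast<int64_t>(block_out_channels.size()) - 1);
  std::cout << vae_scale_factor << "\n";  // prints 8
  return 0;
}
```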
@@ -104,21 +96,21 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
             : std::nullopt;
 
     std::vector<torch::Tensor> output = forward_(
-        prompts,                                       // prompt
-        prompts_2,                                     // prompt_2
-        negative_prompts,                              // negative_prompt
-        negative_prompts_2,                            // negative_prompt_2
-        generation_params.true_cfg_scale,              // cfg scale
-        std::make_optional(generation_params.height),  // height
-        std::make_optional(generation_params.width),   // width
-        generation_params.num_inference_steps,         // num_inference_steps
-        generation_params.guidance_scale,              // guidance_scale
-        generation_params.num_images_per_prompt,       // num_images_per_prompt
-        seed,                                          // seed
-        latents,                                       // latents
-        prompt_embeds,                                 // prompt_embeds
-        negative_prompt_embeds,                        // negative_prompt_embeds
-        pooled_prompt_embeds,                          // pooled_prompt_embeds
+        prompts,                                  // prompt
+        prompts_2,                                // prompt_2
+        negative_prompts,                         // negative_prompt
+        negative_prompts_2,                       // negative_prompt_2
+        generation_params.true_cfg_scale,         // cfg scale
+        generation_params.height,                 // height
+        generation_params.width,                  // width
+        generation_params.num_inference_steps,    // num_inference_steps
+        generation_params.guidance_scale,         // guidance_scale
+        generation_params.num_images_per_prompt,  // num_images_per_prompt
+        seed,                                     // seed
+        latents,                                  // latents
+        prompt_embeds,                            // prompt_embeds
+        negative_prompt_embeds,                   // negative_prompt_embeds
+        pooled_prompt_embeds,                     // pooled_prompt_embeds
         negative_pooled_prompt_embeds,            // negative_pooled_prompt_embeds
         generation_params.max_sequence_length     // max_sequence_length
     );
@@ -141,13 +133,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     LOG(INFO)
         << "Flux model components loaded, start to load weights to sub models";
     transformer_->load_model(std::move(transformer_loader));
-    transformer_->to(device_);
+    transformer_->to(options_.device());
     vae_->load_model(std::move(vae_loader));
-    vae_->to(device_);
+    vae_->to(options_.device());
     t5_->load_model(std::move(t5_loader));
-    t5_->to(device_);
+    t5_->to(options_.device());
     clip_text_model_->load_model(std::move(clip_loader));
-    clip_text_model_->to(device_);
+    clip_text_model_->to(options_.device());
     tokenizer_ = tokenizer_loader->tokenizer();
     tokenizer_2_ = tokenizer_2_loader->tokenizer();
   }
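The `->to(options_.device())` calls replace the removed `device_` cache: the `TensorOptions` held in `options_` becomes the single source of truth for device and dtype. A minimal sketch of that pattern, with illustrative names (`PipelineState` is not this repo's API):

```cpp
#include <torch/torch.h>
#include <iostream>

// Illustrative only: one TensorOptions owns device and dtype, so there is
// no cached device_/dtype_ pair that can drift out of sync with it.
struct PipelineState {
  torch::TensorOptions options;
  torch::Device device() const { return options.device(); }
  torch::ScalarType dtype() const {
    return c10::typeMetaToScalarType(options.dtype());
  }
};

int main() {
  PipelineState s{
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU)};
  auto t = torch::zeros({2, 2}, s.options);
  std::cout << t.device() << " " << t.dtype() << "\n";
  return 0;
}
```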
@@ -186,8 +178,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
       std::optional<std::vector<std::string>> negative_prompt = std::nullopt,
       std::optional<std::vector<std::string>> negative_prompt_2 = std::nullopt,
       float true_cfg_scale = 1.0f,
-      std::optional<int64_t> height = std::nullopt,
-      std::optional<int64_t> width = std::nullopt,
+      int64_t height = 512,
+      int64_t width = 512,
       int64_t num_inference_steps = 28,
       float guidance_scale = 3.5f,
       int64_t num_images_per_prompt = 1,
@@ -199,12 +191,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
       std::optional<torch::Tensor> negative_pooled_prompt_embeds = std::nullopt,
       int64_t max_sequence_length = 512) {
     torch::NoGradGuard no_grad;
-    int64_t actual_height = height.has_value()
-                                ? height.value()
-                                : default_sample_size_ * vae_scale_factor_;
-    int64_t actual_width = width.has_value()
-                               ? width.value()
-                               : default_sample_size_ * vae_scale_factor_;
     int64_t batch_size;
     if (prompt.has_value()) {
       batch_size = prompt.value().size();
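Dropping `default_sample_size_` is a behavior change rather than pure cleanup: the old fallback resolved to `default_sample_size_ * vae_scale_factor_` pixels per side, while the new defaulted parameters are 512. A one-line check of the old arithmetic:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Old fallback: default_sample_size_ * vae_scale_factor_ = 128 * 8 = 1024.
  // Callers that relied on the implicit 1024x1024 default now get 512x512
  // unless they pass an explicit size.
  int64_t default_sample_size = 128;
  int64_t vae_scale_factor = 8;
  std::cout << default_sample_size * vae_scale_factor << "\n";  // 1024
  return 0;
}
```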
@@ -244,8 +230,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     auto [prepared_latents, latent_image_ids] =
         prepare_latents(total_batch_size,
                         num_channels_latents,
-                        actual_height,
-                        actual_width,
+                        height,
+                        width,
                         seed.has_value() ? seed.value() : 42,
                         latents);
     // prepare timestep
@@ -263,7 +249,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
                                scheduler_->base_shift(),
                                scheduler_->max_shift());
     auto [timesteps, num_inference_steps_actual] = retrieve_timesteps(
-        scheduler_, num_inference_steps, device_, new_sigmas, mu);
+        scheduler_, num_inference_steps, options_.device(), new_sigmas, mu);
     int64_t num_warmup_steps =
         std::max(static_cast<int64_t>(timesteps.numel()) -
                      num_inference_steps_actual * scheduler_->order(),
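For context on `mu` in the call above: Flux shifts its flow-matching schedule with resolution, interpolating between `base_shift` and `max_shift` by packed-token count. A standalone sketch modeled on diffusers' `calculate_shift`; the defaults shown are diffusers', and this repo's scheduler config may differ:

```cpp
#include <cstdint>
#include <iostream>

// Linear interpolation of the timestep shift mu by image sequence length,
// mirroring diffusers' calculate_shift() for Flux.
double calculate_shift(int64_t image_seq_len,
                       int64_t base_seq_len = 256,
                       int64_t max_seq_len = 4096,
                       double base_shift = 0.5,
                       double max_shift = 1.15) {
  double m = (max_shift - base_shift) /
             static_cast<double>(max_seq_len - base_seq_len);
  double b = base_shift - m * static_cast<double>(base_seq_len);
  return static_cast<double>(image_seq_len) * m + b;
}

int main() {
  // A 1024x1024 image packs to (1024 / 8 / 2)^2 = 4096 tokens -> mu = 1.15.
  std::cout << calculate_shift(4096) << "\n";
  return 0;
}
```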
@@ -272,7 +258,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     torch::Tensor guidance;
     if (transformer_->guidance_embeds()) {
       torch::TensorOptions options =
-          torch::dtype(torch::kFloat32).device(device_);
+          torch::dtype(torch::kFloat32).device(options_.device());
 
       guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options);
       guidance = guidance.expand({prepared_latents.size(0)});
@@ -284,8 +270,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     auto [rot_emb1, rot_emb2] =
         pos_embed_->forward_cache(text_ids,
                                   latent_image_ids,
-                                  height.value() / (vae_scale_factor_ * 2),
-                                  width.value() / (vae_scale_factor_ * 2));
+                                  height / (vae_scale_factor_ * 2),
+                                  width / (vae_scale_factor_ * 2));
     torch::Tensor image_rotary_emb = torch::stack({rot_emb1, rot_emb2}, 0);
     for (int64_t i = 0; i < timesteps.numel(); ++i) {
       torch::Tensor t = timesteps[i].unsqueeze(0);
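The `height / (vae_scale_factor_ * 2)` grid passed to `pos_embed_` reflects Flux's 2x2 latent patchification: each RoPE position covers a 2x2 block of VAE latents. Spelled out for the new 512x512 default:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Flux packs 2x2 patches of VAE latents into one transformer token, so
  // the RoPE grid per side is height / (vae_scale_factor * 2).
  int64_t height = 512, width = 512, vae_scale_factor = 8;
  int64_t grid_h = height / (vae_scale_factor * 2);  // 32
  int64_t grid_w = width / (vae_scale_factor * 2);   // 32
  std::cout << grid_h << "x" << grid_w << " = " << grid_h * grid_w
            << " image tokens\n";  // 32x32 = 1024
  return 0;
}
```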
@@ -326,13 +312,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     }
     torch::Tensor image;
     // Unpack latents
-    torch::Tensor unpacked_latents = unpack_latents(
-        prepared_latents, actual_height, actual_width, vae_scale_factor_);
+    torch::Tensor unpacked_latents =
+        unpack_latents(prepared_latents, height, width, vae_scale_factor_);
     unpacked_latents =
         (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
-    unpacked_latents = unpacked_latents.to(dtype_);
+    unpacked_latents = unpacked_latents.to(options_.dtype());
     image = vae_->decode(unpacked_latents);
-    image = vae_image_processor_->postprocess(image, "pil");
+    image = vae_image_processor_->postprocess(image);
     return std::vector<torch::Tensor>{{image}};
   }
 
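Before decoding, the latents are denormalized out of the transformer's working space, inverting the encode-side `z_norm = (z - shift) * scale`. A sketch with the published FLUX.1 VAE config values (0.3611 / 0.1159), which are an assumption here since the pipeline reads them from `model_args`:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Inverse of the encode-side normalization z_norm = (z - shift) * scale.
  // The scaling/shift factors below are the FLUX.1 VAE config values;
  // treat them as an assumption, the pipeline reads them from model_args.
  float vae_scaling_factor = 0.3611f;
  float vae_shift_factor = 0.1159f;
  auto latents = torch::randn({1, 16, 64, 64});  // 16 latent channels
  auto z = latents / vae_scaling_factor + vae_shift_factor;
  std::cout << z.sizes() << "\n";  // [1, 16, 64, 64]
  return 0;
}
```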
@@ -343,7 +329,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
   FluxDiTModel transformer_{nullptr};
   float vae_scaling_factor_;
   float vae_shift_factor_;
-  int default_sample_size_;
   FluxPosEmbed pos_embed_{nullptr};
 };
 TORCH_MODULE(FluxPipeline);