1+ import argparse
2+ import json
3+ import os
4+
5+ import torch
6+ import torch .distributed as dist
7+
8+ from fastvideo .v1 .logger import init_logger
9+ from fastvideo .v1 .utils import maybe_download_model , shallow_asdict
10+ from fastvideo .v1 .distributed import init_distributed_environment , initialize_model_parallel
11+ from fastvideo .v1 .fastvideo_args import FastVideoArgs
12+ from fastvideo .v1 .configs .models .vaes import WanVAEConfig
13+ from fastvideo import PipelineConfig
14+ from fastvideo .v1 .pipelines .preprocess_pipeline import PreprocessPipeline
15+
16+ logger = init_logger (__name__ )
17+
18+ def main (args ):
19+ args .model_path = maybe_download_model (args .model_path )
20+ # Assume using torchrun
21+ local_rank = int (os .getenv ("RANK" , 0 ))
22+ rank = int (os .environ .get ("RANK" , 0 ))
23+ world_size = int (os .getenv ("WORLD_SIZE" , 1 ))
24+ init_distributed_environment (world_size = world_size , rank = rank , local_rank = local_rank )
25+ initialize_model_parallel (tensor_model_parallel_size = world_size , sequence_model_parallel_size = world_size )
26+ torch .cuda .set_device (local_rank )
27+ if not dist .is_initialized ():
28+ dist .init_process_group (backend = "nccl" , init_method = "env://" , world_size = world_size , rank = local_rank )
29+
30+ pipeline_config = PipelineConfig .from_pretrained (args .model_path )
31+ kwargs = {
32+ "use_cpu_offload" : False ,
33+ "vae_precision" : "fp32" ,
34+ "vae_config" : WanVAEConfig (load_encoder = True , load_decoder = False ),
35+ }
36+ pipeline_config_args = shallow_asdict (pipeline_config )
37+ pipeline_config_args .update (kwargs )
38+ fastvideo_args = FastVideoArgs (model_path = args .model_path ,
39+ num_gpus = world_size ,
40+ device_str = "cuda" ,
41+ ** pipeline_config_args ,
42+ )
43+ fastvideo_args .check_fastvideo_args ()
44+ fastvideo_args .device = torch .device (f"cuda:{ local_rank } " )
45+
46+ pipeline = PreprocessPipeline (args .model_path , fastvideo_args )
47+ pipeline .forward (batch = None , fastvideo_args = fastvideo_args , args = args )
48+
49+
50+ if __name__ == "__main__" :
51+ parser = argparse .ArgumentParser ()
52+ # dataset & dataloader
53+ parser .add_argument ("--model_path" , type = str , default = "data/mochi" )
54+ parser .add_argument ("--model_type" , type = str , default = "mochi" )
55+ parser .add_argument ("--data_merge_path" , type = str , required = True )
56+ parser .add_argument ("--validation_prompt_txt" , type = str )
57+ parser .add_argument ("--num_frames" , type = int , default = 163 )
58+ parser .add_argument (
59+ "--dataloader_num_workers" ,
60+ type = int ,
61+ default = 1 ,
62+ help = "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ,
63+ )
64+ parser .add_argument (
65+ "--preprocess_video_batch_size" ,
66+ type = int ,
67+ default = 2 ,
68+ help = "Batch size (per device) for the training dataloader." ,
69+ )
70+ parser .add_argument (
71+ "--preprocess_text_batch_size" ,
72+ type = int ,
73+ default = 8 ,
74+ help = "Batch size (per device) for the training dataloader." ,
75+ )
76+ parser .add_argument (
77+ "--samples_per_file" ,
78+ type = int ,
79+ default = 64
80+ )
81+ parser .add_argument (
82+ "--flush_frequency" ,
83+ type = int ,
84+ default = 256 ,
85+ help = "how often to save to parquet files"
86+ )
87+ parser .add_argument ("--num_latent_t" , type = int , default = 28 , help = "Number of latent timesteps." )
88+ parser .add_argument ("--max_height" , type = int , default = 480 )
89+ parser .add_argument ("--max_width" , type = int , default = 848 )
90+ parser .add_argument ("--video_length_tolerance_range" , type = int , default = 2.0 )
91+ parser .add_argument ("--group_frame" , action = "store_true" ) # TODO
92+ parser .add_argument ("--group_resolution" , action = "store_true" ) # TODO
93+ parser .add_argument ("--dataset" , default = "t2v" )
94+ parser .add_argument ("--train_fps" , type = int , default = 30 )
95+ parser .add_argument ("--use_image_num" , type = int , default = 0 )
96+ parser .add_argument ("--text_max_length" , type = int , default = 256 )
97+ parser .add_argument ("--speed_factor" , type = float , default = 1.0 )
98+ parser .add_argument ("--drop_short_ratio" , type = float , default = 1.0 )
99+ # text encoder & vae & diffusion model
100+ parser .add_argument ("--text_encoder_name" , type = str , default = "google/t5-v1_1-xxl" )
101+ parser .add_argument ("--cache_dir" , type = str , default = "./cache_dir" )
102+ parser .add_argument ("--cfg" , type = float , default = 0.0 )
103+ parser .add_argument (
104+ "--output_dir" ,
105+ type = str ,
106+ default = None ,
107+ help = "The output directory where the model predictions and checkpoints will be written." ,
108+ )
109+ parser .add_argument (
110+ "--logging_dir" ,
111+ type = str ,
112+ default = "logs" ,
113+ help = ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
114+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ),
115+ )
116+
117+ args = parser .parse_args ()
118+ main (args )
0 commit comments