Skip to content

Commit f85eb1a

Browse files
[Training] [0/n] Add preprocessing pipeline (hao-ai-lab#442)
Co-authored-by: "BrianChen1129" <[email protected]>
1 parent 6f170a8 commit f85eb1a

File tree

5 files changed

+718
-6
lines changed

5 files changed

+718
-6
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import argparse
2+
import json
3+
import os
4+
5+
import torch
6+
import torch.distributed as dist
7+
8+
from fastvideo.v1.logger import init_logger
9+
from fastvideo.v1.utils import maybe_download_model, shallow_asdict
10+
from fastvideo.v1.distributed import init_distributed_environment, initialize_model_parallel
11+
from fastvideo.v1.fastvideo_args import FastVideoArgs
12+
from fastvideo.v1.configs.models.vaes import WanVAEConfig
13+
from fastvideo import PipelineConfig
14+
from fastvideo.v1.pipelines.preprocess_pipeline import PreprocessPipeline
15+
16+
logger = init_logger(__name__)
17+
18+
def main(args):
    """Run the data preprocessing pipeline under torchrun.

    Downloads the model if needed, initializes the distributed environment
    from torchrun's environment variables, builds FastVideoArgs from the
    pretrained pipeline config, and runs PreprocessPipeline over the dataset
    described by ``args``.

    Args:
        args: parsed CLI namespace (see the argparse definitions below);
            ``args.model_path`` may be rewritten to a local download path.
    """
    args.model_path = maybe_download_model(args.model_path)
    # Assume launch via torchrun, which sets RANK / LOCAL_RANK / WORLD_SIZE.
    # BUG FIX: the local rank must come from LOCAL_RANK, not RANK. On a
    # multi-node run RANK is the *global* rank and would index a CUDA device
    # that does not exist on this node.
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    init_distributed_environment(world_size=world_size,
                                 rank=rank,
                                 local_rank=local_rank)
    initialize_model_parallel(tensor_model_parallel_size=world_size,
                              sequence_model_parallel_size=world_size)
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        # BUG FIX: the process-group rank is the global rank; passing the
        # local rank here would collide across nodes.
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=rank)

    pipeline_config = PipelineConfig.from_pretrained(args.model_path)
    # Preprocessing only encodes: run the VAE encoder in fp32 for fidelity
    # and skip loading the (unused) decoder.
    overrides = {
        "use_cpu_offload": False,
        "vae_precision": "fp32",
        "vae_config": WanVAEConfig(load_encoder=True, load_decoder=False),
    }
    pipeline_config_args = shallow_asdict(pipeline_config)
    pipeline_config_args.update(overrides)
    fastvideo_args = FastVideoArgs(model_path=args.model_path,
                                   num_gpus=world_size,
                                   device_str="cuda",
                                   **pipeline_config_args)
    fastvideo_args.check_fastvideo_args()
    fastvideo_args.device = torch.device(f"cuda:{local_rank}")

    pipeline = PreprocessPipeline(args.model_path, fastvideo_args)
    pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)
48+
49+
50+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocess a video-text dataset into cached latents.")
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--data_merge_path", type=str, required=True)
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that "
        "the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--preprocess_video_batch_size",
        type=int,
        default=2,
        help="Batch size (per device) for video preprocessing.",
    )
    parser.add_argument(
        "--preprocess_text_batch_size",
        type=int,
        default=8,
        # BUG FIX: help text previously duplicated the video batch-size text.
        help="Batch size (per device) for text-encoder preprocessing.",
    )
    parser.add_argument(
        "--samples_per_file",
        type=int,
        default=64,
        help="Number of samples stored per output file.",
    )
    parser.add_argument(
        "--flush_frequency",
        type=int,
        default=256,
        help="How often to save to parquet files.",
    )
    parser.add_argument("--num_latent_t",
                        type=int,
                        default=28,
                        help="Number of latent timesteps.")
    parser.add_argument("--max_height", type=int, default=480)
    parser.add_argument("--max_width", type=int, default=848)
    # BUG FIX: the default was 2.0 (a float) while type was declared as int,
    # so CLI-supplied values were silently truncated; use float consistently.
    parser.add_argument("--video_length_tolerance_range",
                        type=float,
                        default=2.0)
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    parser.add_argument("--dataset", default="t2v")
    parser.add_argument("--train_fps", type=int, default=30)
    parser.add_argument("--use_image_num", type=int, default=0)
    parser.add_argument("--text_max_length", type=int, default=256)
    parser.add_argument("--speed_factor", type=float, default=1.0)
    parser.add_argument("--drop_short_ratio", type=float, default=1.0)
    # text encoder & vae & diffusion model
    parser.add_argument("--text_encoder_name",
                        type=str,
                        default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument("--cfg", type=float, default=0.0)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and "
        "checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log "
              "directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )

    args = parser.parse_args()
    main(args)

fastvideo/v1/dataset/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
from torchvision import transforms
24
from torchvision.transforms import Lambda
35
from transformers import AutoTokenizer
@@ -25,8 +27,8 @@ def getdataset(args, start_idx=0) -> T2V_dataset:
2527
*resize_topcrop,
2628
norm_fun,
2729
])
28-
# tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir)
29-
tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name,
30+
tokenizer_path = os.path.join(args.model_path, "tokenizer")
31+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
3032
cache_dir=args.cache_dir)
3133
if args.dataset == "t2v":
3234
return T2V_dataset(args,

0 commit comments

Comments
 (0)