
Commit 6f170a8

[Training] [5/n] Add single gpu training pipeline (hao-ai-lab#447)

Authored by SolitaryThinker, JerryZhou54, kevin314, and BrianChen1129

Co-authored-by: JerryZhou54 <[email protected]>
Co-authored-by: Wei Zhou <[email protected]>
Co-authored-by: Kevin Lin <[email protected]>
Co-authored-by: “BrianChen1129” <[email protected]>

1 parent 5ede860 · commit 6f170a8

File tree

13 files changed: +1138 −61 lines


.github/workflows/pre-commit.yml
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ jobs:
     - uses: actions/checkout@v4
     - uses: actions/setup-python@v5
       with:
-        python-version: "3.10"
+        python-version: "3.12"
     - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
     - run: echo "::add-matcher::.github/workflows/matchers/mypy.json"
     - uses: pre-commit/[email protected]

.pre-commit-config.yaml
Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@ repos:
     args: [--in-place, --verbose]
     additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.11.4
+  rev: v0.11.12
   hooks:
   - id: ruff
     args: [--output-format, github, --fix]
@@ -48,7 +48,7 @@ repos:
   hooks:
   - id: isort
 - repo: https://github.com/jackdewinter/pymarkdown
-  rev: v0.9.29
+  rev: v0.9.30
   hooks:
   - id: pymarkdown
     args: [fix]

fastvideo/v1/configs/models/vaes/wanvae.py
Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ def __post_init__(self):
 
 @dataclass
 class WanVAEConfig(VAEConfig):
-    arch_config: VAEArchConfig = field(default_factory=WanVAEArchConfig)
+    arch_config: WanVAEArchConfig = field(default_factory=WanVAEArchConfig)
     use_feature_cache: bool = True
 
     use_tiling: bool = False
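
The only change here narrows the declared type of arch_config from the base VAEArchConfig to the concrete WanVAEArchConfig, so type checkers can see Wan-specific arch fields without a cast. A minimal sketch of the pattern, with illustrative field names that are not from the repo:

from dataclasses import dataclass, field


@dataclass
class VAEArchConfig:
    latent_channels: int = 4  # illustrative base field


@dataclass
class WanVAEArchConfig(VAEArchConfig):
    z_dim: int = 16  # illustrative Wan-specific field


@dataclass
class VAEConfig:
    arch_config: VAEArchConfig = field(default_factory=VAEArchConfig)


@dataclass
class WanVAEConfig(VAEConfig):
    # Narrowed annotation: the checker now sees Wan-specific fields directly.
    arch_config: WanVAEArchConfig = field(default_factory=WanVAEArchConfig)


cfg = WanVAEConfig()
print(cfg.arch_config.z_dim)  # type-checks without cast(WanVAEArchConfig, ...)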

fastvideo/v1/distributed/parallel_state.py
Lines changed: 1 addition & 1 deletion

@@ -655,7 +655,7 @@ def recv_tensor_dict(
             tensor_dict[key] = value
         return tensor_dict
 
-    def barrier(self):
+    def barrier(self) -> None:
         """Barrier synchronization among the group.
         NOTE: don't use `device_group` here! `barrier` in NCCL is
         terrible because it is internally a broadcast operation with

fastvideo/v1/fastvideo_args.py
Lines changed: 1 addition & 0 deletions

@@ -478,6 +478,7 @@ class TrainingArgs(FastVideoArgs):
     output_dir: str = ""
     checkpoints_total_limit: int = 0
     checkpointing_steps: int = 0
+    resume_from_checkpoint: bool = False
     logging_dir: str = ""
 
     # optimizer & scheduler
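
resume_from_checkpoint joins the existing checkpointing knobs on TrainingArgs. The consuming code is not part of this diff; below is a plausible single-GPU resume helper under that caveat. The checkpoint directory layout, the state.pt filename, and the state-dict keys are assumptions, not FastVideo APIs:

import os

import torch


def maybe_resume(training_args, transformer, optimizer) -> int:
    """Hypothetical helper: return the step to resume from, or 0."""
    if not training_args.resume_from_checkpoint:
        return 0
    ckpts = [
        d for d in os.listdir(training_args.output_dir)
        if d.startswith("checkpoint-")
    ]
    if not ckpts:
        return 0  # nothing saved yet; start from scratch
    # Pick the highest step numerically, not lexicographically.
    latest = max(ckpts, key=lambda d: int(d.split("-")[-1]))
    state = torch.load(
        os.path.join(training_args.output_dir, latest, "state.pt"))
    transformer.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"]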

fastvideo/v1/pipelines/composed_pipeline_base.py
Lines changed: 150 additions & 14 deletions

@@ -5,19 +5,25 @@
 This module defines the base class for pipelines that are composed of multiple stages.
 """
 
+import argparse
 import os
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 import torch
 
-from fastvideo.v1.fastvideo_args import FastVideoArgs
+from fastvideo.v1.configs.pipelines import (PipelineConfig,
+                                            get_pipeline_config_cls_for_name)
+from fastvideo.v1.distributed import (init_distributed_environment,
+                                      initialize_model_parallel,
+                                      model_parallel_is_initialized)
+from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
 from fastvideo.v1.pipelines.stages import PipelineStage
-from fastvideo.v1.utils import (maybe_download_model,
+from fastvideo.v1.utils import (maybe_download_model, shallow_asdict,
                                 verify_model_config_and_directory)
 
 logger = init_logger(__name__)
@@ -34,20 +40,35 @@ class ComposedPipelineBase(ABC):
 
     is_video_pipeline: bool = False  # To be overridden by video pipelines
     _required_config_modules: List[str] = []
+    training_args: Optional[TrainingArgs] = None
+    fastvideo_args: Optional[FastVideoArgs] = None
 
     # TODO(will): args should support both inference args and training args
     def __init__(self,
                  model_path: str,
                  fastvideo_args: FastVideoArgs,
-                 config: Optional[Dict[str, Any]] = None):
+                 config: Optional[Dict[str, Any]] = None,
+                 required_config_modules: Optional[List[str]] = None):
         """
         Initialize the pipeline. After __init__, the pipeline should be ready to
        use. The pipeline should be stateless and not hold any batch state.
        """
+
+        if fastvideo_args.training_mode:
+            assert isinstance(fastvideo_args, TrainingArgs)
+            self.training_args = fastvideo_args
+            assert self.training_args is not None
+        else:
+            self.fastvideo_args = fastvideo_args
+            assert self.fastvideo_args is not None
+
         self.model_path = model_path
         self._stages: List[PipelineStage] = []
         self._stage_name_mapping: Dict[str, PipelineStage] = {}
 
+        if required_config_modules is not None:
+            self._required_config_modules = required_config_modules
+
         if self._required_config_modules is None:
             raise NotImplementedError(
                 "Subclass must set _required_config_modules")
@@ -59,16 +80,124 @@ def __init__(self,
         else:
             self.config = config
 
+        self.maybe_init_distributed_environment(fastvideo_args)
+
         # Load modules directly in initialization
         logger.info("Loading pipeline modules...")
         self.modules = self.load_modules(fastvideo_args)
 
+        if fastvideo_args.training_mode:
+            assert self.training_args is not None
+            if self.training_args.log_validation:
+                self.initialize_validation_pipeline(self.training_args)
+            self.initialize_training_pipeline(self.training_args)
+
         self.initialize_pipeline(fastvideo_args)
 
-        logger.info("Creating pipeline stages...")
-        self.create_pipeline_stages(fastvideo_args)
+        if not fastvideo_args.training_mode:
+            logger.info("Creating pipeline stages...")
+            self.create_pipeline_stages(fastvideo_args)
+
+    def initialize_training_pipeline(self, training_args: TrainingArgs):
+        raise NotImplementedError(
+            "if training_mode is True, the pipeline must implement this method")
+
+    def initialize_validation_pipeline(self, training_args: TrainingArgs):
+        raise NotImplementedError(
+            "if log_validation is True, the pipeline must implement this method"
+        )
+
+    @classmethod
+    def from_pretrained(cls,
+                        model_path: str,
+                        device: Optional[str] = None,
+                        torch_dtype: Optional[torch.dtype] = None,
+                        pipeline_config: Optional[
+                            Union[str
+                                  | PipelineConfig]] = None,
+                        args: Optional[argparse.Namespace] = None,
+                        required_config_modules: Optional[List[str]] = None,
+                        **kwargs) -> "ComposedPipelineBase":
+        config = None
+        # 1. If users provide a pipeline config, it will override the default pipeline config
+        if isinstance(pipeline_config, PipelineConfig):
+            config = pipeline_config
+        else:
+            config_cls = get_pipeline_config_cls_for_name(model_path)
+            if config_cls is not None:
+                config = config_cls()
+            if isinstance(pipeline_config, str):
+                config.load_from_json(pipeline_config)
+
+        # 2. If users also provide some kwargs, it will override the pipeline config.
+        # The user kwargs shouldn't contain model config parameters!
+        if config is None:
+            logger.warning("No config found for model %s, using default config",
+                           model_path)
+            config_args = kwargs
+        else:
+            config_args = shallow_asdict(config)
+            config_args.update(kwargs)
+
+        if args is None or args.inference_mode:
+            fastvideo_args = FastVideoArgs(model_path=model_path,
+                                           device_str=device or "cuda" if
+                                           torch.cuda.is_available() else "cpu",
+                                           **config_args)
+
+            fastvideo_args.model_path = model_path
+            fastvideo_args.device_str = device or "cuda" if torch.cuda.is_available(
+            ) else "cpu"
+            for key, value in config_args.items():
+                setattr(fastvideo_args, key, value)
+        else:
+            assert args is not None, "args must be provided for training mode"
+            fastvideo_args = TrainingArgs.from_cli_args(args)
+            # TODO(will): fix this so that its not so ugly
+            fastvideo_args.model_path = model_path
+            fastvideo_args.device_str = device or "cuda" if torch.cuda.is_available(
+            ) else "cpu"
+            for key, value in config_args.items():
+                setattr(fastvideo_args, key, value)
+
+            fastvideo_args.use_cpu_offload = False
+            fastvideo_args.inference_mode = False
+
+        logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
+
+        fastvideo_args.check_fastvideo_args()
+
+        return cls(model_path,
+                   fastvideo_args,
+                   required_config_modules=required_config_modules)
+
+    def maybe_init_distributed_environment(self, fastvideo_args: FastVideoArgs):
+        if model_parallel_is_initialized():
+            return
+        local_rank = int(os.environ.get("LOCAL_RANK", -1))
+        world_size = int(os.environ.get("WORLD_SIZE", -1))
+        rank = int(os.environ.get("RANK", -1))
+
+        if local_rank == -1 or world_size == -1 or rank == -1:
+            raise ValueError(
+                "Local rank, world size, and rank must be set. Use torchrun to launch the script."
+            )
 
-    def get_module(self, module_name: str) -> Any:
+        torch.cuda.set_device(local_rank)
+        init_distributed_environment(world_size=world_size,
+                                     rank=rank,
+                                     local_rank=local_rank)
+        assert fastvideo_args.tp_size is not None, "tp_size must be set"
+        assert fastvideo_args.sp_size is not None, "sp_size must be set"
+        initialize_model_parallel(
+            tensor_model_parallel_size=fastvideo_args.tp_size,
+            sequence_model_parallel_size=fastvideo_args.sp_size)
+        device = torch.device(f"cuda:{local_rank}")
+        fastvideo_args.device = device
+
+    def get_module(self, module_name: str, default_value: Any = None) -> Any:
+        if module_name not in self.modules:
+            return default_value
         return self.modules[module_name]
 
     def add_module(self, module_name: str, module: Any):
@@ -114,6 +243,12 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         """
         raise NotImplementedError
 
+    def create_training_stages(self, training_args: TrainingArgs):
+        """
+        Create the training pipeline stages.
+        """
+        raise NotImplementedError
+
     def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
         """
         Initialize the pipeline.
@@ -136,19 +271,21 @@ def load_modules(self, fastvideo_args: FastVideoArgs) -> Dict[str, Any]:
             modules_config
         ) > 1, "model_index.json must contain at least one pipeline module"
 
-        required_modules = [
-            "vae", "text_encoder", "transformer", "scheduler", "tokenizer"
-        ]
-        for module_name in required_modules:
+        for module_name in self.required_config_modules:
             if module_name not in modules_config:
                 raise ValueError(
                     f"model_index.json must contain a {module_name} module")
-        logger.info("Diffusers config passed sanity checks")
 
         # all the component models used by the pipeline
+        required_modules = self.required_config_modules
+        logger.info("Loading required modules: %s", required_modules)
+
         modules = {}
         for module_name, (transformers_or_diffusers,
                           architecture) in modules_config.items():
+            if module_name not in required_modules:
+                logger.info("Skipping module %s", module_name)
+                continue
             component_model_path = os.path.join(self.model_path, module_name)
             module = PipelineComponentLoader.load_module(
                 module_name=module_name,
@@ -164,7 +301,6 @@ def load_modules(self, fastvideo_args: FastVideoArgs) -> Dict[str, Any]:
             logger.warning("Overwriting module %s", module_name)
             modules[module_name] = module
 
-        required_modules = self.required_config_modules
         # Check if all required modules were loaded
        for module_name in required_modules:
            if module_name not in modules or modules[module_name] is None:
@@ -198,7 +334,7 @@ def forward(
         # Execute each stage
         logger.info("Running pipeline stages: %s",
                     self._stage_name_mapping.keys())
-        logger.info("Batch: %s", batch)
+        # logger.info("Batch: %s", batch)
         for stage in self.stages:
             batch = stage(batch, fastvideo_args)
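
Two practical consequences of the new base-class code: maybe_init_distributed_environment insists on RANK, LOCAL_RANK, and WORLD_SIZE, so even a single-GPU job must be launched with torchrun; and get_module now takes a default instead of raising on a missing key. A sketch of the resulting call pattern, using the WanValidationPipeline added below in wan_pipeline.py; the model directory is a placeholder:

# Launch with: torchrun --nproc_per_node=1 run_pipeline.py
# torchrun exports RANK=0, LOCAL_RANK=0, WORLD_SIZE=1, which
# maybe_init_distributed_environment requires even on a single GPU.
from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline

pipeline = WanValidationPipeline.from_pretrained(
    "/path/to/Wan2.1-Diffusers",  # placeholder local model directory
    device="cuda",
)

# get_module with a default: "transformer" is not in this pipeline's
# _required_config_modules, so load_modules skipped it and None comes back
# instead of a KeyError.
transformer = pipeline.get_module("transformer", None)
assert transformer is None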

fastvideo/v1/pipelines/training_utils.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

fastvideo/v1/pipelines/wan/wan_pipeline.py
Lines changed: 27 additions & 1 deletion

@@ -48,7 +48,33 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         self.add_stage(stage_name="latent_preparation_stage",
                        stage=LatentPreparationStage(
                            scheduler=self.get_module("scheduler"),
-                           transformer=self.get_module("transformer")))
+                           transformer=self.get_module("transformer", None)))
+
+        self.add_stage(stage_name="denoising_stage",
+                       stage=DenoisingStage(
+                           transformer=self.get_module("transformer"),
+                           scheduler=self.get_module("scheduler")))
+
+        self.add_stage(stage_name="decoding_stage",
+                       stage=DecodingStage(vae=self.get_module("vae")))
+
+
+class WanValidationPipeline(ComposedPipelineBase):
+    """
+    Validation pipeline for Wan2.1; assumes the inputs are preprocessed latents.
+    """
+    _required_config_modules = ["vae", "scheduler"]
+
+    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
+        """Set up pipeline stages with proper dependency injection."""
+        self.add_stage(stage_name="timestep_preparation_stage",
+                       stage=TimestepPreparationStage(
+                           scheduler=self.get_module("scheduler")))
+
+        self.add_stage(stage_name="latent_preparation_stage",
+                       stage=LatentPreparationStage(
+                           scheduler=self.get_module("scheduler"),
+                           transformer=self.get_module("transformer", None)))
 
         self.add_stage(stage_name="denoising_stage",
                        stage=DenoisingStage(
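
Since load_modules now skips anything outside required_config_modules, WanValidationPipeline touches only the VAE and scheduler, and the get_module("transformer", None) default keeps LatentPreparationStage constructible without a transformer. The per-instance override added to the base-class constructor composes with this; a sketch under the assumption that the extra module exists in the model's model_index.json (path is a placeholder):

from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline

# Per-instance override of the module list: the new required_config_modules
# argument replaces the class-level _required_config_modules, so this
# instance additionally loads the text encoder alongside vae and scheduler.
pipeline = WanValidationPipeline.from_pretrained(
    "/path/to/Wan2.1-Diffusers",  # placeholder local model directory
    required_config_modules=["vae", "scheduler", "text_encoder"],
)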

fastvideo/v1/training/__init__.py

Whitespace-only changes.
