 This module defines the base class for pipelines that are composed of multiple stages.
 """
 
+import argparse
 import os
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 import torch
 
-from fastvideo.v1.fastvideo_args import FastVideoArgs
+from fastvideo.v1.configs.pipelines import (PipelineConfig,
+                                            get_pipeline_config_cls_for_name)
+from fastvideo.v1.distributed import (init_distributed_environment,
+                                      initialize_model_parallel,
+                                      model_parallel_is_initialized)
+from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
 from fastvideo.v1.pipelines.stages import PipelineStage
-from fastvideo.v1.utils import (maybe_download_model,
+from fastvideo.v1.utils import (maybe_download_model, shallow_asdict,
                                 verify_model_config_and_directory)
 
 logger = init_logger(__name__)
@@ -34,20 +40,35 @@ class ComposedPipelineBase(ABC):
 
     is_video_pipeline: bool = False  # To be overridden by video pipelines
     _required_config_modules: List[str] = []
+    training_args: Optional[TrainingArgs] = None
+    fastvideo_args: Optional[FastVideoArgs] = None
 
     # TODO(will): args should support both inference args and training args
     def __init__(self,
                  model_path: str,
                  fastvideo_args: FastVideoArgs,
-                 config: Optional[Dict[str, Any]] = None):
+                 config: Optional[Dict[str, Any]] = None,
+                 required_config_modules: Optional[List[str]] = None):
         """
         Initialize the pipeline. After __init__, the pipeline should be ready to
         use. The pipeline should be stateless and not hold any batch state.
         """
+
+        if fastvideo_args.training_mode:
+            assert isinstance(fastvideo_args, TrainingArgs)
+            self.training_args = fastvideo_args
+            assert self.training_args is not None
+        else:
+            self.fastvideo_args = fastvideo_args
+            assert self.fastvideo_args is not None
+
         self.model_path = model_path
         self._stages: List[PipelineStage] = []
         self._stage_name_mapping: Dict[str, PipelineStage] = {}
 
+        if required_config_modules is not None:
+            self._required_config_modules = required_config_modules
+
         if self._required_config_modules is None:
             raise NotImplementedError(
                 "Subclass must set _required_config_modules")
@@ -59,16 +80,124 @@ def __init__(self,
         else:
             self.config = config
 
+        self.maybe_init_distributed_environment(fastvideo_args)
+
         # Load modules directly in initialization
         logger.info("Loading pipeline modules...")
         self.modules = self.load_modules(fastvideo_args)
 
+        if fastvideo_args.training_mode:
+            assert self.training_args is not None
+            if self.training_args.log_validation:
+                self.initialize_validation_pipeline(self.training_args)
+            self.initialize_training_pipeline(self.training_args)
+
         self.initialize_pipeline(fastvideo_args)
 
-        logger.info("Creating pipeline stages...")
-        self.create_pipeline_stages(fastvideo_args)
+        if not fastvideo_args.training_mode:
+            logger.info("Creating pipeline stages...")
+            self.create_pipeline_stages(fastvideo_args)
+
+    def initialize_training_pipeline(self, training_args: TrainingArgs):
+        raise NotImplementedError(
+            "if training_mode is True, the pipeline must implement this method")
+
+    def initialize_validation_pipeline(self, training_args: TrainingArgs):
+        raise NotImplementedError(
+            "if log_validation is True, the pipeline must implement this method"
+        )
+
+    @classmethod
+    def from_pretrained(cls,
+                        model_path: str,
+                        device: Optional[str] = None,
+                        torch_dtype: Optional[torch.dtype] = None,
+                        pipeline_config: Optional[Union[str,
+                                                        PipelineConfig]] = None,
+                        args: Optional[argparse.Namespace] = None,
+                        required_config_modules: Optional[List[str]] = None,
+                        **kwargs) -> "ComposedPipelineBase":
+        config = None
+        # 1. If users provide a pipeline config, it will override the default pipeline config.
+        if isinstance(pipeline_config, PipelineConfig):
+            config = pipeline_config
+        else:
+            config_cls = get_pipeline_config_cls_for_name(model_path)
+            if config_cls is not None:
+                config = config_cls()
+                if isinstance(pipeline_config, str):
+                    config.load_from_json(pipeline_config)
+
+        # 2. If users also provide some kwargs, they will override the pipeline config.
+        # The user kwargs shouldn't contain model config parameters!
+        if config is None:
+            logger.warning("No config found for model %s, using default config",
+                           model_path)
+            config_args = kwargs
+        else:
+            config_args = shallow_asdict(config)
+            config_args.update(kwargs)
+
+        if args is None or args.inference_mode:
+            fastvideo_args = FastVideoArgs(
+                model_path=model_path,
+                device_str=device or
+                ("cuda" if torch.cuda.is_available() else "cpu"),
+                **config_args)
+
+            fastvideo_args.model_path = model_path
+            fastvideo_args.device_str = device or (
+                "cuda" if torch.cuda.is_available() else "cpu")
+            for key, value in config_args.items():
+                setattr(fastvideo_args, key, value)
+        else:
+            assert args is not None, "args must be provided for training mode"
+            fastvideo_args = TrainingArgs.from_cli_args(args)
+            # TODO(will): fix this so that it's not so ugly
+            fastvideo_args.model_path = model_path
+            fastvideo_args.device_str = device or (
+                "cuda" if torch.cuda.is_available() else "cpu")
+            for key, value in config_args.items():
+                setattr(fastvideo_args, key, value)
+
+            fastvideo_args.use_cpu_offload = False
+            fastvideo_args.inference_mode = False
+
+        logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
+
+        fastvideo_args.check_fastvideo_args()
+
+        return cls(model_path,
+                   fastvideo_args,
+                   required_config_modules=required_config_modules)
+
+    def maybe_init_distributed_environment(self, fastvideo_args: FastVideoArgs):
+        if model_parallel_is_initialized():
+            return
+        local_rank = int(os.environ.get("LOCAL_RANK", -1))
+        world_size = int(os.environ.get("WORLD_SIZE", -1))
+        rank = int(os.environ.get("RANK", -1))
+
+        if local_rank == -1 or world_size == -1 or rank == -1:
+            raise ValueError(
+                "Local rank, world size, and rank must be set. Use torchrun to launch the script."
+            )
 
-    def get_module(self, module_name: str) -> Any:
+        torch.cuda.set_device(local_rank)
+        init_distributed_environment(world_size=world_size,
+                                     rank=rank,
+                                     local_rank=local_rank)
+        assert fastvideo_args.tp_size is not None, "tp_size must be set"
+        assert fastvideo_args.sp_size is not None, "sp_size must be set"
+        initialize_model_parallel(
+            tensor_model_parallel_size=fastvideo_args.tp_size,
+            sequence_model_parallel_size=fastvideo_args.sp_size)
+        device = torch.device(f"cuda:{local_rank}")
+        fastvideo_args.device = device
+
+    def get_module(self, module_name: str, default_value: Any = None) -> Any:
+        if module_name not in self.modules:
+            return default_value
         return self.modules[module_name]
 
     def add_module(self, module_name: str, module: Any):
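For inference, `from_pretrained` resolves a registered pipeline config for the model name (optionally overridden by a user-supplied `PipelineConfig` or JSON path, then by kwargs), builds `FastVideoArgs`, and constructs the pipeline; given a training CLI namespace it builds `TrainingArgs` instead. A minimal usage sketch, assuming a hypothetical concrete subclass `MyVideoPipeline` and a script launched via torchrun so LOCAL_RANK/RANK/WORLD_SIZE are set:

```python
# Hedged usage sketch: "MyVideoPipeline" and "org/model-name" are placeholders,
# not names from this repository.
pipeline = MyVideoPipeline.from_pretrained(
    "org/model-name",   # local directory or hub id (downloaded if missing)
    device="cuda",
)

# get_module() now accepts a default instead of raising KeyError, which is
# convenient for optional components a given checkpoint may not ship.
image_encoder = pipeline.get_module("image_encoder", None)
```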
@@ -114,6 +243,12 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
114243 """
115244 raise NotImplementedError
116245
246+ def create_training_stages (self , training_args : TrainingArgs ):
247+ """
248+ Create the training pipeline stages.
249+ """
250+ raise NotImplementedError
251+
117252 def initialize_pipeline (self , fastvideo_args : FastVideoArgs ):
118253 """
119254 Initialize the pipeline.
@@ -136,19 +271,21 @@ def load_modules(self, fastvideo_args: FastVideoArgs) -> Dict[str, Any]:
             modules_config
         ) > 1, "model_index.json must contain at least one pipeline module"
 
-        required_modules = [
-            "vae", "text_encoder", "transformer", "scheduler", "tokenizer"
-        ]
-        for module_name in required_modules:
+        for module_name in self.required_config_modules:
             if module_name not in modules_config:
                 raise ValueError(
                     f"model_index.json must contain a {module_name} module")
-        logger.info("Diffusers config passed sanity checks")
 
         # all the component models used by the pipeline
+        required_modules = self.required_config_modules
+        logger.info("Loading required modules: %s", required_modules)
+
         modules = {}
         for module_name, (transformers_or_diffusers,
                           architecture) in modules_config.items():
+            if module_name not in required_modules:
+                logger.info("Skipping module %s", module_name)
+                continue
             component_model_path = os.path.join(self.model_path, module_name)
             module = PipelineComponentLoader.load_module(
                 module_name=module_name,
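`load_modules` walks a diffusers-style `model_index.json`, where each entry maps a module name to a (library, class) pair; with this change, entries not listed in `required_config_modules` are skipped instead of loaded. An illustrative shape of the parsed config (the class names are examples, not taken from any specific checkpoint):

```python
# Illustrative modules_config as parsed from model_index.json; only keys listed
# in required_config_modules are actually loaded.
modules_config = {
    "vae": ("diffusers", "AutoencoderKL"),
    "text_encoder": ("transformers", "T5EncoderModel"),
    "tokenizer": ("transformers", "T5Tokenizer"),
    "transformer": ("diffusers", "Transformer2DModel"),
    "scheduler": ("diffusers", "FlowMatchEulerDiscreteScheduler"),
}
```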
@@ -164,7 +301,6 @@ def load_modules(self, fastvideo_args: FastVideoArgs) -> Dict[str, Any]:
                 logger.warning("Overwriting module %s", module_name)
             modules[module_name] = module
 
-        required_modules = self.required_config_modules
         # Check if all required modules were loaded
         for module_name in required_modules:
             if module_name not in modules or modules[module_name] is None:
@@ -198,7 +334,7 @@ def forward(
         # Execute each stage
         logger.info("Running pipeline stages: %s",
                     self._stage_name_mapping.keys())
-        logger.info("Batch: %s", batch)
+        # logger.info("Batch: %s", batch)
         for stage in self.stages:
             batch = stage(batch, fastvideo_args)
 
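`forward` is a left-to-right pass over the registered stages: each stage is invoked as a callable with the current batch and the runtime args, and must return the batch handed to the next stage. The loop is equivalent to this standalone sketch (the function name is illustrative, not part of the API):

```python
# Equivalent of the loop in forward(): stages run in registration order and
# thread the ForwardBatch through the pipeline.
def run_stages(stages, batch, fastvideo_args):
    for stage in stages:
        batch = stage(batch, fastvideo_args)
    return batch
```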