@@ -21,7 +21,7 @@
 
 from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.method.simple_average import simple_average
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.models.utils import named_leaf_modules
 from fusion_bench.utils import seed_everything_by_time
 from fusion_bench.utils.dtype import dtype_support_svd
@@ -35,7 +35,7 @@
 
 
 @auto_register_config
-class DOPMerging(BaseAlgorithm, LightningFabricMixin):
+class DOPMerging(LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm):
     """
     Dual Projections for Balancing Stability and Plasticity (DOP) merging algorithm.
 
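Note on the reordered bases: listing the mixins before BaseAlgorithm places them earlier in Python's method resolution order, so their hooks can wrap or override the base class's behavior. A minimal sketch of why mixin order matters (illustrative classes, not fusion_bench's actual hierarchy):

class Base:
    def setup(self):
        print("base setup")

class ProfilerMixin:
    def setup(self):
        print("profiler setup")  # runs first because the mixin precedes Base
        super().setup()          # cooperative call continues down the MRO

class Algo(ProfilerMixin, Base):
    pass

Algo().setup()  # prints "profiler setup", then "base setup"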
@@ -167,18 +167,25 @@ def run(self, modelpool: BaseModelPool):
167167 f"--------- Optimizing { model_idx + 1 } /{ len (model_names )} -th with { model_name } ---------"
168168 )
169169 if model_idx == 0 :
170- merged_model = modelpool .load_model (model_names [0 ])
170+ print ("Using the first model as the initial merged model." )
171+ with self .profile ("loading models" ):
172+ merged_model = modelpool .load_model (model_names [0 ])
171173 else :
172- merged_model = self ._layer_wise_optimize (
173- model_names = ["merged" , model_name ],
174- pretrained_model = deepcopy (pretrained_model ),
175- finetuned_models = {
176- "merged" : merged_model ,
177- model_name : modelpool .load_model (model_name ),
178- },
179- model_idx = model_idx ,
180- )
174+ with self .profile ("loading models" ):
175+ finetuned_model = modelpool .load_model (model_name )
176+ with self .profile ("DOP merging" ):
177+ merged_model = self ._layer_wise_optimize (
178+ model_names = ["merged" , model_name ],
179+ pretrained_model = deepcopy (pretrained_model ),
180+ finetuned_models = {
181+ "merged" : merged_model ,
182+ model_name : finetuned_model ,
183+ },
184+ model_idx = model_idx ,
185+ )
186+ del finetuned_model
181187
188+ self .print_profile_summary ()
182189 return merged_model
183190
184191 def _optimize_linear_layer (
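The run() changes above wrap the two expensive phases in self.profile(...) blocks and emit a single timing summary at the end; the added del finetuned_model also releases each task checkpoint before the next one is loaded, keeping peak memory at roughly two models instead of three. A minimal sketch of how such a context-manager profiler can be built (an illustrative assumption; SimpleProfilerMixin's real implementation in fusion_bench may differ):

import time
from collections import defaultdict
from contextlib import contextmanager

class TimingMixin:
    """Accumulates wall-clock time per named section."""

    @contextmanager
    def profile(self, name):
        if not hasattr(self, "_timings"):
            self._timings = defaultdict(float)
        start = time.perf_counter()
        try:
            yield
        finally:
            self._timings[name] += time.perf_counter() - start

    def print_profile_summary(self):
        for name, seconds in getattr(self, "_timings", {}).items():
            print(f"{name}: {seconds:.2f}s")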
@@ -246,12 +253,13 @@ def _layer_wise_optimize(
                     module.weight.data = merged_weight.data
                 else:
                     if not self.ray_actor_pool.has_free():
-                        module_name, merged_weight = (
+                        returned_module_name, merged_weight = (
                             self.ray_actor_pool.get_next_unordered()
                         )
-                        pretrained_model.get_submodule(module_name).weight.data = (
-                            merged_weight
-                        )
+                        print(f"merged weight {returned_module_name} from ray actors.")
+                        pretrained_model.get_submodule(
+                            returned_module_name
+                        ).weight.data = merged_weight
                     self.ray_actor_pool.submit(
                         lambda actor, kwargs: actor._optimize_linear_layer.remote(
                             **kwargs
@@ -275,6 +283,7 @@ def _layer_wise_optimize(
         if self.num_ray_actors > 0:
             while self.ray_actor_pool.has_next():
                 module_name, merged_weight = self.ray_actor_pool.get_next_unordered()
+                print(f"merged weight {module_name} from ray actors.")
                 pretrained_model.get_submodule(module_name).weight.data = merged_weight
 
         return pretrained_model
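The ray path above follows a saturate-and-drain pattern: submit one task per linear layer, harvest a finished result whenever no actor is free, and drain the remaining in-flight results after the loop. The rename to returned_module_name also stops the harvested result from clobbering the loop's own module_name before it is submitted. A self-contained sketch of the same pattern using ray.util.ActorPool (the Worker class and its optimize method are placeholders, not fusion_bench code):

import ray
from ray.util import ActorPool

@ray.remote
class Worker:
    def optimize(self, name, weight):
        return name, weight * 2.0  # stand-in for per-layer weight optimization

ray.init()
pool = ActorPool([Worker.remote() for _ in range(4)])
tasks = {f"layer{i}": float(i) for i in range(10)}
results = {}
for name, weight in tasks.items():
    if not pool.has_free():  # all actors busy: collect one result first
        done_name, merged = pool.get_next_unordered()
        results[done_name] = merged
    pool.submit(lambda actor, kv: actor.optimize.remote(*kv), (name, weight))
while pool.has_next():  # drain whatever is still in flight
    done_name, merged = pool.get_next_unordered()
    results[done_name] = merged
ray.shutdown()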
@@ -360,7 +369,9 @@ def _optimize_weight(
         all_losses = [[], []]
         all_alphas = [[], []]
         for step_idx in tqdm(
-            range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            range(self.num_steps),
+            desc=f"Optimizing {module_name} weight",
+            disable=self.num_ray_actors > 0,
         ):
             # Scaling the loss functions based on the algorithm choice
             loss_data = {}
@@ -421,7 +432,9 @@ def _optimize_weight(
         # This is a naive weighted optimization
         optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
         for step_idx in tqdm(
-            range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            range(self.num_steps),
+            desc=f"Optimizing {module_name} weight",
+            disable=self.num_ray_actors > 0,
         ):
             loss = 0
             for i, finetuned_weight in enumerate(finetuned_weights.values()):
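Both tqdm loops gain disable=self.num_ray_actors > 0: when layers are optimized inside ray actors, each actor would render its own progress bar and interleave control codes in the driver log, so the bars are switched off. tqdm treats disable=True as a transparent pass-through, e.g.:

from tqdm import tqdm

in_actor = True  # hypothetical flag standing in for self.num_ray_actors > 0
for step in tqdm(range(100), desc="Optimizing", disable=in_actor):
    pass  # loop body runs normally; no progress bar is drawn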