
Commit 842332f

Refactor DOPMerging class to include SimpleProfilerMixin and enhance model loading with profiling; update LightningFabricMixin to default to 1 device if no configuration is found.
1 parent 6027fa2

3 files changed: +34 −21 lines

examples/dop/distributed_dop.py

Lines changed: 1 addition & 1 deletion

@@ -20,5 +20,5 @@
 }
 
 algo: DOPMerging = DOPMerging.from_yaml(config_file)
-algo.num_ray_actors = 2 # set the number of ray actors to use for parallel merging
+algo.num_ray_actors = 2 # set the number of ray actors to use for parallel merging
 algo.run(BaseModelPool(models))
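
For context, a hedged end-to-end sketch of how this example is meant to be driven. The config path and model identifiers below are hypothetical placeholders, not part of this commit:

```python
from fusion_bench import BaseModelPool
from fusion_bench.method.dop.dop_general import DOPMerging

# Hypothetical inputs for illustration only.
config_file = "config/method/dop_merging.yaml"
models = {
    "model_a": "path/or/hub-id-for-model-a",
    "model_b": "path/or/hub-id-for-model-b",
}

algo: DOPMerging = DOPMerging.from_yaml(config_file)
algo.num_ray_actors = 2  # merge layers in parallel across 2 Ray actors
merged_model = algo.run(BaseModelPool(models))
```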

fusion_bench/method/dop/dop_general.py

Lines changed: 31 additions & 18 deletions
@@ -21,7 +21,7 @@
 
 from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.method.simple_average import simple_average
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.models.utils import named_leaf_modules
 from fusion_bench.utils import seed_everything_by_time
 from fusion_bench.utils.dtype import dtype_support_svd
@@ -35,7 +35,7 @@
 
 
 @auto_register_config
-class DOPMerging(BaseAlgorithm, LightningFabricMixin):
+class DOPMerging(LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm):
     """
     Dual Projections for Balancing Stability and Plasticity (DOP) merging algorithm.
 
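
The reordered base classes change Python's method resolution order (MRO), so the mixins' hooks are found before BaseAlgorithm's defaults. A minimal, hypothetical illustration of the effect (these stub classes are not FusionBench code):

```python
# Hypothetical stub classes, purely to show the resulting MRO.
class BaseAlgorithm: ...
class LightningFabricMixin: ...
class SimpleProfilerMixin: ...

class DOPMerging(LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm): ...

print([cls.__name__ for cls in DOPMerging.__mro__])
# ['DOPMerging', 'LightningFabricMixin', 'SimpleProfilerMixin', 'BaseAlgorithm', 'object']
```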
@@ -167,18 +167,25 @@ def run(self, modelpool: BaseModelPool):
                 f"--------- Optimizing {model_idx + 1}/{len(model_names)}-th with {model_name} ---------"
             )
             if model_idx == 0:
-                merged_model = modelpool.load_model(model_names[0])
+                print("Using the first model as the initial merged model.")
+                with self.profile("loading models"):
+                    merged_model = modelpool.load_model(model_names[0])
             else:
-                merged_model = self._layer_wise_optimize(
-                    model_names=["merged", model_name],
-                    pretrained_model=deepcopy(pretrained_model),
-                    finetuned_models={
-                        "merged": merged_model,
-                        model_name: modelpool.load_model(model_name),
-                    },
-                    model_idx=model_idx,
-                )
+                with self.profile("loading models"):
+                    finetuned_model = modelpool.load_model(model_name)
+                with self.profile("DOP merging"):
+                    merged_model = self._layer_wise_optimize(
+                        model_names=["merged", model_name],
+                        pretrained_model=deepcopy(pretrained_model),
+                        finetuned_models={
+                            "merged": merged_model,
+                            model_name: finetuned_model,
+                        },
+                        model_idx=model_idx,
+                    )
+                del finetuned_model
 
+        self.print_profile_summary()
         return merged_model
 
     def _optimize_linear_layer(
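
The `self.profile(...)` blocks added above come from SimpleProfilerMixin. As a rough mental model, such a mixin can be built around a context manager that accumulates wall-clock time per label; the sketch below is an assumption-laden stand-in, not FusionBench's actual implementation:

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class ProfilerSketch:
    """Illustrative stand-in for SimpleProfilerMixin: times named code regions."""

    def __init__(self):
        self._totals = defaultdict(float)  # label -> accumulated seconds

    @contextmanager
    def profile(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self._totals[name] += time.perf_counter() - start

    def print_profile_summary(self):
        for name, seconds in sorted(self._totals.items()):
            print(f"{name}: {seconds:.2f}s total")
```

With this shape, each `with self.profile("loading models"): ...` costs one timer call on entry and one on exit, so it is cheap enough to wrap every model load.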
@@ -246,12 +253,13 @@ def _layer_wise_optimize(
                     module.weight.data = merged_weight.data
                 else:
                     if not self.ray_actor_pool.has_free():
-                        module_name, merged_weight = (
+                        returned_module_name, merged_weight = (
                             self.ray_actor_pool.get_next_unordered()
                         )
-                        pretrained_model.get_submodule(module_name).weight.data = (
-                            merged_weight
-                        )
+                        print(f"merged weight {returned_module_name} from ray actors.")
+                        pretrained_model.get_submodule(
+                            returned_module_name
+                        ).weight.data = merged_weight
                     self.ray_actor_pool.submit(
                         lambda actor, kwargs: actor._optimize_linear_layer.remote(
                             **kwargs
@@ -275,6 +283,7 @@ def _layer_wise_optimize(
         if self.num_ray_actors > 0:
             while self.ray_actor_pool.has_next():
                 module_name, merged_weight = self.ray_actor_pool.get_next_unordered()
+                print(f"merged weight {module_name} from ray actors.")
                 pretrained_model.get_submodule(module_name).weight.data = merged_weight
 
         return pretrained_model
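
The `has_free()` / `get_next_unordered()` calls follow Ray's `ActorPool` back-pressure pattern: harvest one finished result whenever all actors are busy, then drain the pool once everything has been submitted. A self-contained toy version of that pattern (the `WeightOptimizer` actor and dummy weights are illustrative, not FusionBench code):

```python
import ray
from ray.util import ActorPool

@ray.remote
class WeightOptimizer:
    # Toy stand-in for the per-layer optimizer actor.
    def optimize(self, module_name, weight):
        return module_name, weight * 0.5  # (name, merged_weight), as the driver expects

ray.init(ignore_reinit_error=True)
pool = ActorPool([WeightOptimizer.remote() for _ in range(2)])

results = {}
for name, weight in {"layers.0": 1.0, "layers.1": 2.0, "layers.2": 3.0}.items():
    if not pool.has_free():  # back-pressure: all actors busy, harvest one result
        done_name, merged = pool.get_next_unordered()
        results[done_name] = merged
    pool.submit(lambda actor, args: actor.optimize.remote(*args), (name, weight))

while pool.has_next():  # final drain, mirroring the loop in the diff above
    done_name, merged = pool.get_next_unordered()
    results[done_name] = merged
```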
@@ -360,7 +369,9 @@ def _optimize_weight(
         all_losses = [[], []]
         all_alphas = [[], []]
         for step_idx in tqdm(
-            range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            range(self.num_steps),
+            desc=f"Optimizing {module_name} weight",
+            disable=self.num_ray_actors > 0,
         ):
             # Scaling the loss functions based on the algorithm choice
             loss_data = {}
@@ -421,7 +432,9 @@ def _optimize_weight(
         # This is a naive weighted optimization
         optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
         for step_idx in tqdm(
-            range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            range(self.num_steps),
+            desc=f"Optimizing {module_name} weight",
+            disable=self.num_ray_actors > 0,
         ):
             loss = 0
             for i, finetuned_weight in enumerate(finetuned_weights.values()):
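
Passing `disable=self.num_ray_actors > 0` is standard tqdm usage: when the per-layer optimization runs inside Ray actors, many processes would otherwise write interleaved progress bars to the driver's terminal, so the bars are suppressed in distributed runs. A minimal sketch (the values are illustrative):

```python
from tqdm import tqdm

num_ray_actors = 2  # illustrative; any value > 0 silences the bar
for step_idx in tqdm(range(100), desc="Optimizing weight", disable=num_ray_actors > 0):
    pass  # the optimization step would run here
```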

fusion_bench/mixins/lightning_fabric.py

Lines changed: 2 additions & 2 deletions
@@ -110,8 +110,8 @@ def setup_lightning_fabric(self, config: DictConfig):
         """
         if self._fabric_instance is None:
             if config.get("fabric", None) is None:
-                log.warning("No fabric configuration found. use default settings.")
-                self._fabric_instance = L.Fabric()
+                log.warning("No fabric configuration found; using default settings (1 device).")
+                self._fabric_instance = L.Fabric(devices=1)
             else:
                 self._fabric_instance = instantiate(config.fabric)
         if not _is_using_cli(): # if not using cli, launch the fabric
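
Without an explicit `devices` argument, `L.Fabric()` resolves `devices="auto"` and may claim every visible accelerator; pinning `devices=1` makes the single-device fallback deterministic. A minimal standalone check of the new default (illustrative, not FusionBench code):

```python
import lightning as L
import torch

fabric = L.Fabric(devices=1)  # the new default when no fabric config is given
fabric.launch()

model = fabric.setup(torch.nn.Linear(4, 2))  # placed on the Fabric-chosen device
print(fabric.device)
```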
