Skip to content

Commit 259cdc6

Browse files
committed
Add distributed_ray example and enhance DOPMerging with GPU support
1 parent 4b0637a commit 259cdc6

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

examples/dop/distributed_ray.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from transformers import CLIPVisionModel
2+
3+
from fusion_bench import BaseModelPool
4+
from fusion_bench.constants.paths import DEFAULT_CONFIG_PATH
5+
from fusion_bench.method.dop.dop_general import DOPMerging
6+
from fusion_bench.utils import timeit_context
7+
8+
config_file = DEFAULT_CONFIG_PATH / "method/dop/dop_general.yaml"
9+
10+
11+
with timeit_context("loading models"):
12+
models = {
13+
"_pretrained_": CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32"),
14+
"sun397": CLIPVisionModel.from_pretrained(
15+
"tanganke/clip-vit-base-patch32_sun397"
16+
),
17+
"stanford-cars": CLIPVisionModel.from_pretrained(
18+
"tanganke/clip-vit-base-patch32_stanford-cars"
19+
),
20+
}
21+
22+
algo: DOPMerging = DOPMerging.from_yaml(config_file)
23+
algo.num_ray_actors = 2 # set the number of ray actors to use for parallel merging
24+
algo.run(BaseModelPool(models))

fusion_bench/method/dop/dop_general.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,13 @@ def run(self, modelpool: BaseModelPool):
140140
ray.init()
141141

142142
# create actors
143+
if self.fabric.device.type == "cuda":
144+
actor_options = {"num_gpus": 1}
145+
else:
146+
actor_options = {}
143147
self.ray_actor_pool = ActorPool(
144148
[
145-
DOPMergingActor.remote(**self.config)
149+
DOPMergingActor.options(**actor_options).remote(**self.config)
146150
for _ in range(self.num_ray_actors)
147151
]
148152
)
@@ -250,7 +254,7 @@ def _layer_wise_optimize(
250254
)
251255
self.ray_actor_pool.submit(
252256
lambda actor, kwargs: actor._optimize_linear_layer.remote(
253-
*kwargs
257+
**kwargs
254258
),
255259
{
256260
"module_name": module_name,

fusion_bench/models/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def is_leaf_module(module: nn.Module) -> bool:
1515
def named_leaf_modules(
1616
module: nn.Module,
1717
prefix: str = "",
18+
ignore_empty: bool = True,
1819
) -> Iterable[tuple[str, nn.Module]]:
1920
"""
2021
Recursively find the leaf modules in a module.
@@ -28,6 +29,8 @@ def named_leaf_modules(
2829
"""
2930
for name, submodule in module.named_modules(prefix=prefix):
3031
if is_leaf_module(submodule):
32+
if ignore_empty and len(list(submodule.parameters())) == 0:
33+
continue
3134
yield name, submodule
3235

3336

0 commit comments

Comments
 (0)