File tree Expand file tree Collapse file tree 3 files changed +33
-2
lines changed
Expand file tree Collapse file tree 3 files changed +33
-2
lines changed Original file line number Diff line number Diff line change 1+ from transformers import CLIPVisionModel
2+
3+ from fusion_bench import BaseModelPool
4+ from fusion_bench .constants .paths import DEFAULT_CONFIG_PATH
5+ from fusion_bench .method .dop .dop_general import DOPMerging
6+ from fusion_bench .utils import timeit_context
7+
8+ config_file = DEFAULT_CONFIG_PATH / "method/dop/dop_general.yaml"
9+
10+
11+ with timeit_context ("loading models" ):
12+ models = {
13+ "_pretrained_" : CLIPVisionModel .from_pretrained ("openai/clip-vit-base-patch32" ),
14+ "sun397" : CLIPVisionModel .from_pretrained (
15+ "tanganke/clip-vit-base-patch32_sun397"
16+ ),
17+ "stanford-cars" : CLIPVisionModel .from_pretrained (
18+ "tanganke/clip-vit-base-patch32_stanford-cars"
19+ ),
20+ }
21+
22+ algo : DOPMerging = DOPMerging .from_yaml (config_file )
23+ algo .num_ray_actors = 2 # set the number of ray actors to use for parallel merging
24+ algo .run (BaseModelPool (models ))
Original file line number Diff line number Diff line change @@ -140,9 +140,13 @@ def run(self, modelpool: BaseModelPool):
140140 ray .init ()
141141
142142 # create actors
143+ if self .fabric .device .type == "cuda" :
144+ actor_options = {"num_gpus" : 1 }
145+ else :
146+ actor_options = {}
143147 self .ray_actor_pool = ActorPool (
144148 [
145- DOPMergingActor .remote (** self .config )
149+ DOPMergingActor .options ( ** actor_options ). remote (** self .config )
146150 for _ in range (self .num_ray_actors )
147151 ]
148152 )
@@ -250,7 +254,7 @@ def _layer_wise_optimize(
250254 )
251255 self .ray_actor_pool .submit (
252256 lambda actor , kwargs : actor ._optimize_linear_layer .remote (
253- * kwargs
257+ ** kwargs
254258 ),
255259 {
256260 "module_name" : module_name ,
Original file line number Diff line number Diff line change @@ -15,6 +15,7 @@ def is_leaf_module(module: nn.Module) -> bool:
1515def named_leaf_modules (
1616 module : nn .Module ,
1717 prefix : str = "" ,
18+ ignore_empty : bool = True ,
1819) -> Iterable [tuple [str , nn .Module ]]:
1920 """
2021 Recursively find the leaf modules in a module.
@@ -28,6 +29,8 @@ def named_leaf_modules(
2829 """
2930 for name , submodule in module .named_modules (prefix = prefix ):
3031 if is_leaf_module (submodule ):
32+ if ignore_empty and len (list (submodule .parameters ())) == 0 :
33+ continue
3134 yield name , submodule
3235
3336
You can’t perform that action at this time.
0 commit comments