1515import argparse
1616import logging
1717import os
18+ import math
1819import random
1920import time
2021import json
@@ -85,6 +86,11 @@ def parse_args():
8586 type = float ,
8687 default = 1.0 ,
8788 help = "width mult you want to export" )
89+ parser .add_argument (
90+ '--depth_mult' ,
91+ type = float ,
92+ default = 1.0 ,
93+ help = "depth mult you want to export" )
8894 args = parser .parse_args ()
8995 return args
9096
@@ -106,6 +112,18 @@ def do_train(args):
106112 model_class , tokenizer_class = MODEL_CLASSES [args .model_type ]
107113 config_path = os .path .join (args .model_name_or_path , 'model_config.json' )
108114 cfg_dict = dict (json .loads (open (config_path ).read ()))
115+
116+ if args .depth_mult < 1.0 :
117+ depth = round (cfg_dict ["init_args" ][0 ]['num_hidden_layers' ] * args .depth_mult )
118+ cfg_dict ["init_args" ][0 ]['num_hidden_layers' ] = depth
119+ kept_layers_index = {}
120+ for idx , i in enumerate (range (1 , depth + 1 )):
121+ kept_layers_index [idx ] = math .floor (i / args .depth_mult ) - 1
122+
123+ os .rename (config_path , config_path + '_bak' )
124+ with open (config_path , "w" , encoding = "utf-8" ) as f :
125+ f .write (json .dumps (cfg_dict , ensure_ascii = False ))
126+
109127 num_labels = cfg_dict ['num_classes' ]
110128
111129 model = model_class .from_pretrained (
@@ -114,14 +132,24 @@ def do_train(args):
114132 origin_model = model_class .from_pretrained (
115133 args .model_name_or_path , num_classes = num_labels )
116134
135+ os .rename (config_path + '_bak' , config_path )
136+
117137 sp_config = supernet (expand_ratio = [1.0 , args .width_mult ])
118138 model = Convert (sp_config ).convert (model )
119139
120140 ofa_model = OFA (model )
121141
122142 sd = paddle .load (
123143 os .path .join (args .model_name_or_path , 'model_state.pdparams' ))
124- ofa_model .model .set_state_dict (sd )
144+
145+ for name , params in ofa_model .model .named_parameters ():
146+ if 'encoder' not in name :
147+ params .set_value (sd [name ])
148+ else :
149+ idx = int (name .strip ().split ('.' )[3 ])
150+ mapping_name = name .replace ('.' + str (idx )+ '.' , '.' + str (kept_layers_index [idx ])+ '.' )
151+ params .set_value (sd [mapping_name ])
152+
125153 best_config = utils .dynabert_config (ofa_model , args .width_mult )
126154 ofa_model .export (
127155 best_config ,
0 commit comments