Skip to content

Commit bd727d0

Browse files
authored
add depth export (PaddlePaddle#853)
1 parent 5e92215 commit bd727d0

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

examples/model_compression/ofa/export_model.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import argparse
1616
import logging
1717
import os
18+
import math
1819
import random
1920
import time
2021
import json
@@ -85,6 +86,11 @@ def parse_args():
8586
type=float,
8687
default=1.0,
8788
help="width mult you want to export")
89+
parser.add_argument(
90+
'--depth_mult',
91+
type=float,
92+
default=1.0,
93+
help="depth mult you want to export")
8894
args = parser.parse_args()
8995
return args
9096

@@ -106,6 +112,18 @@ def do_train(args):
106112
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
107113
config_path = os.path.join(args.model_name_or_path, 'model_config.json')
108114
cfg_dict = dict(json.loads(open(config_path).read()))
115+
116+
if args.depth_mult < 1.0:
117+
depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] * args.depth_mult)
118+
cfg_dict["init_args"][0]['num_hidden_layers'] = depth
119+
kept_layers_index = {}
120+
for idx, i in enumerate(range(1, depth+1)):
121+
kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1
122+
123+
os.rename(config_path, config_path+'_bak')
124+
with open(config_path, "w", encoding="utf-8") as f:
125+
f.write(json.dumps(cfg_dict, ensure_ascii=False))
126+
109127
num_labels = cfg_dict['num_classes']
110128

111129
model = model_class.from_pretrained(
@@ -114,14 +132,24 @@ def do_train(args):
114132
origin_model = model_class.from_pretrained(
115133
args.model_name_or_path, num_classes=num_labels)
116134

135+
os.rename(config_path+'_bak', config_path)
136+
117137
sp_config = supernet(expand_ratio=[1.0, args.width_mult])
118138
model = Convert(sp_config).convert(model)
119139

120140
ofa_model = OFA(model)
121141

122142
sd = paddle.load(
123143
os.path.join(args.model_name_or_path, 'model_state.pdparams'))
124-
ofa_model.model.set_state_dict(sd)
144+
145+
for name, params in ofa_model.model.named_parameters():
146+
if 'encoder' not in name:
147+
params.set_value(sd[name])
148+
else:
149+
idx = int(name.strip().split('.')[3])
150+
mapping_name = name.replace('.'+str(idx)+'.', '.'+str(kept_layers_index[idx])+'.')
151+
params.set_value(sd[mapping_name])
152+
125153
best_config = utils.dynabert_config(ofa_model, args.width_mult)
126154
ofa_model.export(
127155
best_config,

0 commit comments

Comments
 (0)