
Commit 0ef02af

Forgot weight dumping script
1 parent 5917870 commit 0ef02af

File tree

1 file changed (+93, -0)

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import os
import sys
import argparse

import torch
from torch import nn


sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import wexchange.c_export  # CWriter below is reached via wexchange.c_export

import rnnoise
#from models import model_dict

unquantized = ['conv1', 'dense_out', 'vad_dense']

description = f"""
This is an unsafe dumping script for RNNoise models. It assumes that all weights are included in Linear, Conv1d, GRU,
GRUCell or Embedding layers and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers to be excluded:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='rnnoise_data', help='filename for source and header file (.c and .h will be added), defaults to rnnoise_data')
parser.add_argument('--struct-name', type=str, default='RNNoise', help='name for C struct, defaults to RNNoise')
parser.add_argument('--quantize', action='store_true', help='apply quantization')

if __name__ == "__main__":
    args = parser.parse_args()

    print(f"loading weights from {args.weightfile}...")
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    saved_gen['model_args'] = ()
    #saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}
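    # Assumed checkpoint layout, inferred from the lookups in this script: a dict with
    # 'state_dict', 'model_args' and 'model_kwargs' entries, e.g. one saved via
    #   torch.save({'state_dict': m.state_dict(), 'model_args': (), 'model_kwargs': {...}}, path)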

    model = rnnoise.RNNoise(*saved_gen['model_args'], **saved_gen['model_kwargs'])
    model.load_state_dict(saved_gen['state_dict'], strict=False)
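    # Weight norm, if it was used during training, stores each weight as a weight_g/weight_v
    # pair; remove_weight_norm() folds them back into a plain .weight tensor for dumping.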
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)


    print("dumping model...")
    quantize_model = args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

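    # Per the options above, CWriter is expected to emit a .c/.h pair named after
    # --export-filename, holding the dumped weights for a struct named --struct-name.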
    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name, add_typedef=True)

    for name, module in model.named_modules():

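        # Layer-by-layer policy: with --quantize, everything except the layers listed in
        # `unquantized` is quantized (scale=None presumably lets the exporter choose a
        # scale); unquantized/float layers use a fixed scale of 1/128.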
        if quantize_model:
            quantize = name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale, input_sparse=True, recurrent_sparse=True)

        elif isinstance(module, nn.GRUCell):
            print(f"dumping GRUCell layer {name}...")
            wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.Embedding):
            print(f"dumping Embedding layer {name}...")
            wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
            #wexchange.torch.dump_torch_embedding_weights(writer, module)

        else:
            print(f"Ignoring layer {name}...")

    writer.close()
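
A hedged usage sketch (the diff doesn't show the script's path, so the filename below is a placeholder):

    python dump_rnnoise_weights.py model_checkpoint.pth c_export --quantize

With the defaults above, this would write c_export/rnnoise_data.c and c_export/rnnoise_data.h declaring an RNNoise struct, quantizing every layer except conv1, dense_out and vad_dense.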
