Skip to content

Commit c103210

Browse files
committed
Add
1 parent 01fc1d8 commit c103210

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

bin/publish.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
import os
3+
import argparse
4+
import random
5+
import yaml
6+
import numpy as np
7+
import hparams as hp
8+
9+
from data.audio import save_wav, inv_mel_spectrogram
10+
from model.generator import MelGANGenerator
11+
from model.generator import MultiBandHiFiGANGenerator
12+
from model.generator import HiFiGANGenerator
13+
from model.generator import BasisMelGANGenerator
14+
15+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16+
17+
18+
def publish_model(checkpoint_path, config_path, model_name, save_path):
19+
with open(config_path) as f:
20+
config = yaml.load(f, Loader=yaml.Loader)
21+
print(f"Loading Model of {model_name}...")
22+
if model_name == "melgan":
23+
model = MelGANGenerator(in_channels=config["in_channels"],
24+
out_channels=config["out_channels"],
25+
kernel_size=config["kernel_size"],
26+
channels=config["channels"],
27+
upsample_scales=config["upsample_scales"],
28+
stack_kernel_size=config["stack_kernel_size"],
29+
stacks=config["stacks"],
30+
use_weight_norm=config["use_weight_norm"],
31+
use_causal_conv=config["use_causal_conv"]).to(device)
32+
elif model_name == "hifigan":
33+
model = HiFiGANGenerator(resblock_kernel_sizes=config["resblock_kernel_sizes"],
34+
upsample_rates=config["upsample_rates"],
35+
upsample_initial_channel=config["upsample_initial_channel"],
36+
resblock_type=config["resblock_type"],
37+
upsample_kernel_sizes=config["upsample_kernel_sizes"],
38+
resblock_dilation_sizes=config["resblock_dilation_sizes"],
39+
transposedconv=config["transposedconv"],
40+
bias=config["bias"]).to(device)
41+
elif model_name == "multiband-hifigan":
42+
model = MultiBandHiFiGANGenerator(resblock_kernel_sizes=config["resblock_kernel_sizes"],
43+
upsample_rates=config["upsample_rates"],
44+
upsample_initial_channel=config["upsample_initial_channel"],
45+
resblock_type=config["resblock_type"],
46+
upsample_kernel_sizes=config["upsample_kernel_sizes"],
47+
resblock_dilation_sizes=config["resblock_dilation_sizes"],
48+
transposedconv=config["transposedconv"],
49+
bias=config["bias"]).to(device)
50+
elif model_name == "basis-melgan":
51+
basis_signal_weight = torch.zeros(config["L"], config["out_channels"]).float()
52+
model = BasisMelGANGenerator(basis_signal_weight=basis_signal_weight,
53+
L=config["L"],
54+
in_channels=config["in_channels"],
55+
out_channels=config["out_channels"],
56+
kernel_size=config["kernel_size"],
57+
channels=config["channels"],
58+
upsample_scales=config["upsample_scales"],
59+
stack_kernel_size=config["stack_kernel_size"],
60+
stacks=config["stacks"],
61+
use_weight_norm=config["use_weight_norm"],
62+
use_causal_conv=config["use_causal_conv"],
63+
transposedconv=config["transposedconv"]).to(device)
64+
else:
65+
raise Exception("no model find!")
66+
model.load_state_dict(torch.load(os.path.join(checkpoint_path), map_location=torch.device(device))['model'])
67+
if model_name == "basis-melgan":
68+
with torch.no_grad():
69+
bias = model.inference(torch.zeros(30000, 80)) # support up to synthesize 300s waveform
70+
pattern = bias.cpu().numpy()
71+
published_dict = {
72+
'model': model.state_dict(),
73+
'pattern': pattern
74+
}
75+
torch.save(published_dict, save_path)
76+
model.eval()
77+
model.remove_weight_norm()
78+
return
79+
80+
81+
def run_publisher():
82+
parser = argparse.ArgumentParser()
83+
parser.add_argument('--checkpoint_path', type=str)
84+
parser.add_argument("--model_name", type=str, help="melgan, hifigan and multiband-hifigan.")
85+
parser.add_argument("--config", type=str, help="path to model configuration file")
86+
parser.add_argument("--save_path", type=str, help="path to save published model")
87+
args = parser.parse_args()
88+
publish_model(args.checkpoint_path, args.config, args.model_name, args.save_path)

0 commit comments

Comments
 (0)