|
| 1 | +import torch |
| 2 | +import os |
| 3 | +import argparse |
| 4 | +import random |
| 5 | +import yaml |
| 6 | +import numpy as np |
| 7 | +import hparams as hp |
| 8 | + |
| 9 | +from data.audio import save_wav, inv_mel_spectrogram |
| 10 | +from model.generator import MelGANGenerator |
| 11 | +from model.generator import MultiBandHiFiGANGenerator |
| 12 | +from model.generator import HiFiGANGenerator |
| 13 | +from model.generator import BasisMelGANGenerator |
| 14 | + |
# Run on GPU when available; every generator built below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
| 16 | + |
| 17 | + |
def _build_generator(model_name, config):
    """Construct the generator named by ``model_name`` from ``config`` on ``device``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported generators.
    """
    if model_name == "melgan":
        return MelGANGenerator(in_channels=config["in_channels"],
                               out_channels=config["out_channels"],
                               kernel_size=config["kernel_size"],
                               channels=config["channels"],
                               upsample_scales=config["upsample_scales"],
                               stack_kernel_size=config["stack_kernel_size"],
                               stacks=config["stacks"],
                               use_weight_norm=config["use_weight_norm"],
                               use_causal_conv=config["use_causal_conv"]).to(device)
    if model_name == "hifigan":
        return HiFiGANGenerator(resblock_kernel_sizes=config["resblock_kernel_sizes"],
                                upsample_rates=config["upsample_rates"],
                                upsample_initial_channel=config["upsample_initial_channel"],
                                resblock_type=config["resblock_type"],
                                upsample_kernel_sizes=config["upsample_kernel_sizes"],
                                resblock_dilation_sizes=config["resblock_dilation_sizes"],
                                transposedconv=config["transposedconv"],
                                bias=config["bias"]).to(device)
    if model_name == "multiband-hifigan":
        return MultiBandHiFiGANGenerator(resblock_kernel_sizes=config["resblock_kernel_sizes"],
                                         upsample_rates=config["upsample_rates"],
                                         upsample_initial_channel=config["upsample_initial_channel"],
                                         resblock_type=config["resblock_type"],
                                         upsample_kernel_sizes=config["upsample_kernel_sizes"],
                                         resblock_dilation_sizes=config["resblock_dilation_sizes"],
                                         transposedconv=config["transposedconv"],
                                         bias=config["bias"]).to(device)
    if model_name == "basis-melgan":
        # The real basis-signal weights come from the checkpoint; a zero
        # placeholder of the right shape is enough to build the module.
        basis_signal_weight = torch.zeros(config["L"], config["out_channels"]).float()
        return BasisMelGANGenerator(basis_signal_weight=basis_signal_weight,
                                    L=config["L"],
                                    in_channels=config["in_channels"],
                                    out_channels=config["out_channels"],
                                    kernel_size=config["kernel_size"],
                                    channels=config["channels"],
                                    upsample_scales=config["upsample_scales"],
                                    stack_kernel_size=config["stack_kernel_size"],
                                    stacks=config["stacks"],
                                    use_weight_norm=config["use_weight_norm"],
                                    use_causal_conv=config["use_causal_conv"],
                                    transposedconv=config["transposedconv"]).to(device)
    # ValueError is a subclass of Exception, so existing broad handlers still catch it.
    raise ValueError(f"unknown model name: {model_name}")


def publish_model(checkpoint_path, config_path, model_name, save_path):
    """Load a trained vocoder checkpoint and write a publishable checkpoint.

    Builds the generator described by the YAML file at ``config_path``,
    restores its weights from ``checkpoint_path`` (a training checkpoint
    whose ``'model'`` entry holds the state dict), and saves a published
    dict to ``save_path``.  For ``"basis-melgan"`` the published dict also
    carries the precomputed inference bias pattern.

    Args:
        checkpoint_path: path to the training checkpoint file.
        config_path: path to the YAML model configuration.
        model_name: "melgan", "hifigan", "multiband-hifigan" or "basis-melgan".
        save_path: destination path for the published checkpoint.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    with open(config_path) as f:
        # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
        # configs are trusted local files here, but prefer yaml.safe_load if
        # untrusted configs ever become possible.
        config = yaml.load(f, Loader=yaml.Loader)
    print(f"Loading Model of {model_name}...")
    model = _build_generator(model_name, config)
    # os.path.join() with a single argument was a no-op; load the path directly.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model'])
    # Fix: previously the published dict always referenced `pattern`, which is
    # only computed for basis-melgan — publishing any other model failed.
    published_dict = {'model': model.state_dict()}
    if model_name == "basis-melgan":
        with torch.no_grad():
            # 30000 frames of 80-dim mel: supports synthesizing up to ~300 s of waveform.
            bias = model.inference(torch.zeros(30000, 80))
        published_dict['pattern'] = bias.cpu().numpy()
    torch.save(published_dict, save_path)
    # Kept after the save on purpose: remove_weight_norm() rewrites parameter
    # names, and the published state dict must match the training layout.
    model.eval()
    model.remove_weight_norm()
    return
| 79 | + |
| 80 | + |
def run_publisher():
    """Parse command-line arguments and publish the selected vocoder model.

    Command-line flags:
        --checkpoint_path: path to the training checkpoint to publish.
        --model_name: which generator to build (see choices below).
        --config: path to the YAML model configuration file.
        --save_path: where to write the published checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str,
                        help="path to the training checkpoint to publish")
    # Fix: help text previously omitted basis-melgan, which publish_model supports.
    parser.add_argument("--model_name", type=str,
                        help="melgan, hifigan, multiband-hifigan and basis-melgan.")
    parser.add_argument("--config", type=str, help="path to model configuration file")
    parser.add_argument("--save_path", type=str, help="path to save published model")
    args = parser.parse_args()
    publish_model(args.checkpoint_path, args.config, args.model_name, args.save_path)
0 commit comments