# coding=utf8

import os
import sys

import librosa
import numpy as np
import onnxruntime as ort
from pypinyin import lazy_pinyin
from tqdm import tqdm

from inference.svs.opencpop.map import cpop_pinyin2ph_func
from utils.audio import save_wav
from utils.hparams import set_hparams, hparams
from utils.text_encoder import TokenTextEncoder
root_dir = os.path.dirname(os.path.abspath(__file__))
os.environ['PYTHONPATH'] = root_dir

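# set_hparams() parses command-line arguments, so patch sys.argv here to point
# it at the config file and experiment name this script expects.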
sys.argv = [
    f'{root_dir}/inference/svs/ds_e2e.py',
    '--config',
    f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
    '--exp_name',
    '0228_opencpop_ds100_rel'
]

spec_max = None
spec_min = None


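# spec_min / spec_max are filled in __main__ from hparams; denorm_spec maps the
# diffusion output from the normalized [-1, 1] range back to the mel-spectrogram range.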
def denorm_spec(x):
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min


class TestAllInfer:
    def __init__(self, hparams):
        self.hparams = hparams

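        # Opencpop phoneme set: pinyin initials/finals plus AP (breath) and SP (silence).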
        phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er",
                      "f", "g", "h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k",
                      "l", "m", "n", "o", "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan",
                      "uang", "ui", "un", "uo", "v", "van", "ve", "vn", "w", "x", "y", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(
            None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = cpop_pinyin2ph_func()
        self.spk_map = {'opencpop': 0}

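        # Three exported ONNX graphs: xiaoma_pe (pitch extractor), hifigan
        # (vocoder), and the acoustic model split into a FastSpeech-style
        # encoder (singer_fs) plus a diffusion denoiser (singer_denoise).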
        print("load pe")
        self.pe2 = ort.InferenceSession("xiaoma_pe.onnx", providers=["CUDAExecutionProvider"])
        print("load hifigan")
        self.vocoder2 = ort.InferenceSession("hifigan.onnx", providers=["CUDAExecutionProvider"])
        print("load singer_fs")
        self.model2 = ort.InferenceSession("singer_fs.onnx", providers=["CUDAExecutionProvider"])
        ips = self.model2.get_inputs()
        print(len(ips))
        for i, ip in enumerate(ips):
            print(f'{i}. {ip.name}')

        print("load singer_denoise")
        self.model3 = ort.InferenceSession("singer_denoise.onnx", providers=["CUDAExecutionProvider"])
        ips = self.model3.get_inputs()
        print(len(ips))
        for i, ip in enumerate(ips):
            print(f'{i}. {ip.name}')

        print("load over")

    def run_vocoder(self, c, **kwargs):
        # c = c.transpose(2, 1)  # [B, 80, T]
        c = np.transpose(c, (0, 2, 1))
        f0 = kwargs.get('f0')  # [B, T]

        if f0 is not None and hparams.get('use_nsf'):
            ort_inputs = {
                'x': c,
                'f0': f0
            }
        else:
            # without NSF the exported vocoder graph only consumes the mel input
            ort_inputs = {
                'x': c
            }

        ort_out = self.vocoder2.run(None, ort_inputs)
        y = ort_out[0]  # [T]

        return y[None]

    def preprocess_word_level_input(self, inp):
        # pypinyin cannot resolve the polyphonic character 长 (cháng/zhǎng), so
        # substitute the homophone 常 before g2p. A better g2p module is welcome via pull requests.
        text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
            .replace('那么长', '那么常').replace('多长', '多常') \
            .replace('很长', '很常')

        # lyric
        pinyins = lazy_pinyin(text_raw, strict=False)
        ph_per_word_lst = [self.pinyin2phs[py.strip()] for py in pinyins if py.strip() in self.pinyin2phs]

        # note
        note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of note windows. ',
                  'You should split the note(s) for each word by the | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
            return None

        note_lst = []
        ph_lst = []
        midi_dur_lst = []
        is_slur = []
        for idx, ph_per_word in enumerate(ph_per_word_lst):
            # phs in one word:
            # a single ph like ['ai'] or multiple phs like ['n', 'i']
            ph_in_this_word = ph_per_word.split()

            # notes in one word:
            # a single note like ['D4'] or multiple notes like ['D4', 'E4'], which mark a slur.
            note_in_this_word = note_per_word_lst[idx].split()
            midi_dur_in_this_word = mididur_per_word_lst[idx].split()
            # process for the model input
            # Step 1.
            # Deal with notes of the 'not slur' case, or the first note of a slur:
            #  j        ie
            #  F#4/Gb4  F#4/Gb4
            #  0        0
            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)
            # Step 2.
            # Deal with the 2nd, 3rd... notes of a slur:
            #  j        ie       ie
            #  F#4/Gb4  F#4/Gb4  C#4/Db4
            #  0        0        1
            if len(note_in_this_word) > 1:  # is_slur = True: repeat the final (YUNMU) phoneme to match the 2nd, 3rd... notes
                for n_idx in range(1, len(note_in_this_word)):
                    ph_lst.append(ph_in_this_word[-1])
                    note_lst.append(note_in_this_word[n_idx])
                    midi_dur_lst.append(midi_dur_in_this_word[n_idx])
                    is_slur.append(1)
        ph_seq = ' '.join(ph_lst)

        if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
            print(len(ph_lst), len(note_lst), len(midi_dur_lst))
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of note windows. ',
                  'You should split the note(s) for each word by the | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_phoneme_level_input(self, inp):
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        is_slur = np.array(inp['is_slur_seq'].split(), 'float')
        ph_dur = None
        if inp['ph_dur'] is not None:
            ph_dur = np.array(inp['ph_dur'].split(), 'float')
            print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst), len(ph_dur))
            if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst) == len(ph_dur):
                print('Pass word-notes check.')
            else:
                print('The number of words doesn\'t match the number of note windows. ',
                      'You should split the note(s) for each word by the | mark.')
                return None
        else:
            print('Automatic phone duration mode')
            print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
            if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
                print('Pass word-notes check.')
            else:
                print('The number of words doesn\'t match the number of note windows. ',
                      'You should split the note(s) for each word by the | mark.')
                return None
        return ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur

    def preprocess_input(self, inp, input_type='word'):
        """

        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return:
        """

        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'opencpop')

        # single spk
        spk_id = self.spk_map[spk_name]

        # get ph seq, note lst, midi dur lst, is slur lst.
        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':  # like transcriptions.txt in the Opencpop dataset.
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if ret:
            if input_type == 'word':
                ph_seq, note_lst, midi_dur_lst, is_slur = ret
            else:
                ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur = ret
        else:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None

        # convert note lst to midi id; convert note dur lst to midi duration
        try:
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:
            print(e)
            print('Invalid note or duration format.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur), 'ph_dur': None}
        item['ph_len'] = len(item['ph_token'])
        if input_type == 'phoneme':
            item['ph_dur'] = ph_dur
        return item

    def input_to_batch(self, item):
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = np.array(item['ph_token'], np.int64)[None, :]
        # txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = np.array([txt_tokens.shape[1]], np.int64)
        # txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        spk_ids = np.array([item['spk_id']], np.int64)  # [B]
        # spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device)

        pitch_midi = np.array(item['pitch_midi'], np.int64)[None, :hparams['max_frames']]
        # pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
        midi_dur = np.array(item['midi_dur'], np.float32)[None, :hparams['max_frames']]
        # midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
        is_slur = np.array(item['is_slur'], np.int64)[None, :hparams['max_frames']]
        # is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)
        mel2ph = None

        # if item['ph_dur'] is not None:
        #     ph_acc = np.around(np.add.accumulate(24000 * item['ph_dur'] / 128)).astype('int')
        #     ph_dur = np.diff(ph_acc, prepend=0)
        #     ph_dur = np.array(ph_dur, np.int64)[None, :hparams['max_frames']]
        #     lr = LengthRegulator()
        #     mel2ph = lr(ph_dur, txt_tokens == 0).detach()

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
            'pitch_midi': pitch_midi,
            'midi_dur': midi_dur,
            'is_slur': is_slur,
            'mel2ph': mel2ph
        }
        return batch

    def forward_model(self, inp):
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')
        mel2ph = sample['mel2ph']

        decoder_inp = self.model2.run(
            None,
            {
                "txt_tokens": txt_tokens,
                # "spk_id": spk_id,
                "pitch_midi": sample['pitch_midi'],
                "midi_dur": sample['midi_dur'],
                "is_slur": sample['is_slur'],
                # "mel2ph": np.array([0, 0]).astype(np.int64)
            }
        )
        cond = np.transpose(decoder_inp[0], (0, 2, 1))
        # cond = torch.from_numpy(decoder_inp[0]).transpose(1, 2)

        t = hparams['K_step']
        print('===> Gaussian noise start.')
        shape = (cond.shape[0], 1,
                 hparams['audio_num_mel_bins'], cond.shape[2])
        # x = torch.randn(shape)
        # x = torch.zeros(shape, device=device)
        x = np.random.randn(*shape).astype(np.float32)

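        # Reverse diffusion: starting from pure noise, run the denoiser for
        # K_step steps from t = K_step - 1 down to 0, conditioned on the encoder output.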
        for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
            res2 = self.model3.run(
                None,
                {
                    "x": x,
                    "t": np.array([i]).astype(np.int64),
                    "cond": cond,
                }
            )
            x = res2[0]

        # x = x[:, 0].transpose(1, 2)
        x = np.transpose(x[:, 0], (0, 2, 1))

        if mel2ph is not None:  # for singing
            mel_out = denorm_spec(x) * (mel2ph > 0).astype(np.float32)[:, :, None]
        else:
            mel_out = denorm_spec(x)

        # mel_out = output['mel_out']  # [B, T, 80]

        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            pe2_res = self.pe2.run(
                None,
                {
                    'mel_input': mel_out
                }
            )

            # pe predicts f0 from the predicted mel
            f0_pred = pe2_res[1]
        else:
            # f0_pred = output['f0_denorm']
            f0_pred = None

        # run vocoder
        wav_out = self.run_vocoder(mel_out, f0=f0_pred)
        # wav_out = wav_out.cpu().numpy()
        return wav_out[0]

    def postprocess_output(self, output):
        return output

    def infer_once(self, inp):
        inp = self.preprocess_input(inp, input_type=inp.get('input_type', 'word'))
        if inp is None:
            return None
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output


if __name__ == '__main__':
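    # Word-level demo input: notes and note durations are '|'-separated per
    # character; a character carrying several notes is sung as a slur.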
    c = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
        'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
        'input_type': 'word'
    }  # user input: Chinese characters

    target = "./infer_out/onnx_test_singer_res.wav"

    set_hparams(print_hparams=False)

    spec_min = np.array(hparams['spec_min'], np.float32)[None, None, :hparams['keep_bins']]
    spec_max = np.array(hparams['spec_max'], np.float32)[None, None, :hparams['keep_bins']]

    infer_ins = TestAllInfer(hparams)

    out = infer_ins.infer_once(c)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    print(f'| save audio: {target}')
    save_wav(out, target, hparams['audio_sample_rate'])

    print("OK")