Commit beba3ad: Merge branch 'openvpi:master' into master
2 parents 781df74 + 4a43e7a

my_numpy.py (382 additions, 0 deletions)
# coding=utf8

import os
import sys

import librosa
import numpy as np
import onnxruntime as ort
from pypinyin import lazy_pinyin
from tqdm import tqdm

from inference.svs.opencpop.map import cpop_pinyin2ph_func
from utils.audio import save_wav
from utils.hparams import set_hparams, hparams
from utils.text_encoder import TokenTextEncoder

root_dir = os.path.dirname(os.path.abspath(__file__))
os.environ['PYTHONPATH'] = f'"{root_dir}"'
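
# set_hparams() reads sys.argv, so the block below fakes the DiffSinger CLI
# invocation (config file + experiment name) before hparams are loaded.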
sys.argv = [
    f'{root_dir}/inference/svs/ds_e2e.py',
    '--config',
    f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
    '--exp_name',
    '0228_opencpop_ds100_rel'
]
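
# spec_min/spec_max are the mel-spectrogram normalization bounds, filled from
# hparams in __main__; denorm_spec() maps the diffusion output from [-1, 1]
# back to the [spec_min, spec_max] mel range.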
spec_max = None
spec_min = None


def denorm_spec(x):
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min


class TestAllInfer:
    def __init__(self, hparams):
        self.hparams = hparams

        phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er",
                      "f", "g", "h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k",
                      "l", "m", "n", "o", "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan",
                      "uang", "ui", "un", "uo", "v", "van", "ve", "vn", "w", "x", "y", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = cpop_pinyin2ph_func()
        self.spk_map = {'opencpop': 0}
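
        # Load the exported ONNX models; the file names suggest a pitch
        # extractor (xiaoma_pe), the HiFi-GAN vocoder, a FastSpeech-style
        # encoder (singer_fs) and the diffusion denoiser (singer_denoise).
        # Input names are printed for quick sanity checks.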
print("load pe")
58+
self.pe2 = ort.InferenceSession("xiaoma_pe.onnx", providers=["CUDAExecutionProvider"])
59+
print("load hifigan")
60+
self.vocoder2 = ort.InferenceSession("hifigan.onnx", providers=["CUDAExecutionProvider"])
61+
print("load singer_fs")
62+
self.model2 = ort.InferenceSession("singer_fs.onnx", providers=["CUDAExecutionProvider"])
63+
ips = self.model2.get_inputs()
64+
print(len(ips))
65+
for i in range(0, len(ips)):
66+
print(f'{i}. {ips[i].name}')
67+
68+
print("load singer_denoise")
69+
self.model3 = ort.InferenceSession("singer_denoise.onnx", providers=["CUDAExecutionProvider"])
70+
ips = self.model3.get_inputs()
71+
print(len(ips))
72+
for i in range(0, len(ips)):
73+
print(f'{i}. {ips[i].name}')
74+
75+
print("load over")
76+
77+
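
    # Vocoder stage: transpose the mel to [B, 80, T] and run HiFi-GAN,
    # conditioned on f0 when 'use_nsf' is enabled in hparams.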
    def run_vocoder(self, c, **kwargs):
        c = np.transpose(c, (0, 2, 1))  # c = c.transpose(2, 1) -> [B, 80, T]
        f0 = kwargs.get('f0')  # [B, T]

        if f0 is not None and hparams.get('use_nsf'):
            ort_inputs = {
                'x': c,
                'f0': f0
            }
        else:
            # The original fed 'f0': {}, which is not a valid ONNX input tensor;
            # pass only the mel when NSF conditioning is unused.
            ort_inputs = {
                'x': c
            }

        ort_out = self.vocoder2.run(None, ort_inputs)
        y = ort_out[0]  # [T]

        return y[None]
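
    # Word-level input: Chinese text plus per-word note/duration windows are
    # expanded into per-phoneme note, duration and slur sequences.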
    def preprocess_word_level_input(self, inp):
        # Pypinyin can't solve polyphonic words
        text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
            .replace('那么长', '那么常').replace('多长', '多常') \
            .replace('很长', '很常')  # We hope someone could provide a better g2p module for us by opening pull requests.

        # lyric
        pinyins = lazy_pinyin(text_raw, strict=False)
        ph_per_word_lst = [self.pinyin2phs[py.strip()] for py in pinyins if py.strip() in self.pinyin2phs]

        # Note
        note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by the | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
            return None

        note_lst = []
        ph_lst = []
        midi_dur_lst = []
        is_slur = []
        for idx, ph_per_word in enumerate(ph_per_word_lst):
            # for phs in one word:
            #   single ph like ['ai'] or multiple phs like ['n', 'i']
            ph_in_this_word = ph_per_word.split()

            # for notes in one word:
            #   single note like ['D4'] or multiple notes like ['D4', 'E4'], which means a 'slur' here.
            note_in_this_word = note_per_word_lst[idx].split()
            midi_dur_in_this_word = mididur_per_word_lst[idx].split()
            # process for the model input
            # Step 1.
            #   Deal with the note of the 'not slur' case or the first note of the 'slur' case
            #   j        ie
            #   F#4/Gb4  F#4/Gb4
            #   0        0
            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)
            # Step 2.
            #   Deal with the 2nd, 3rd... notes of the 'slur' case
            #   j        ie       ie
            #   F#4/Gb4  F#4/Gb4  C#4/Db4
            #   0        0        1
            if len(note_in_this_word) > 1:  # is_slur = True: repeat the YUNMU to match the 2nd, 3rd... notes.
                for note_idx in range(1, len(note_in_this_word)):
                    ph_lst.append(ph_in_this_word[-1])
                    note_lst.append(note_in_this_word[note_idx])
                    midi_dur_lst.append(midi_dur_in_this_word[note_idx])
                    is_slur.append(1)
        ph_seq = ' '.join(ph_lst)

        if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
            print(len(ph_lst), len(note_lst), len(midi_dur_lst))
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by the | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur
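
    # Phoneme-level input: sequences come pre-aligned, like transcriptions.txt
    # in the Opencpop dataset; 'ph_dur' optionally gives per-phoneme durations.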
    def preprocess_phoneme_level_input(self, inp):
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        is_slur = np.array(inp['is_slur_seq'].split(), 'float')
        ph_dur = None
        if inp['ph_dur'] is not None:
            ph_dur = np.array(inp['ph_dur'].split(), 'float')
            print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst), len(ph_dur))
            if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst) == len(ph_dur):
                print('Pass word-notes check.')
            else:
                print('The number of words doesn\'t match the number of notes\' windows. ',
                      'You should split the note(s) for each word by the | mark.')
                return None
        else:
            print('Automatic phone duration mode')
            print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
            if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
                print('Pass word-notes check.')
            else:
                print('The number of words doesn\'t match the number of notes\' windows. ',
                      'You should split the note(s) for each word by the | mark.')
                return None
        return ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur
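
    # Dispatch on input_type, convert note names to MIDI ids with librosa, and
    # pack everything into the item dict consumed by input_to_batch().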
    def preprocess_input(self, inp, input_type='word'):
        """
        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return:
        """
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'opencpop')

        # single spk
        spk_id = self.spk_map[spk_name]

        # get ph seq, note lst, midi dur lst, is slur lst.
        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':  # like transcriptions.txt in the Opencpop dataset.
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if ret:
            if input_type == 'word':
                ph_seq, note_lst, midi_dur_lst, is_slur = ret
            else:
                ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur = ret
        else:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None

        # convert note lst to midi id; convert note dur lst to midi duration
        try:
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:
            print(e)
            print('Invalid note or duration input.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur), 'ph_dur': None}
        item['ph_len'] = len(item['ph_token'])
        if input_type == 'phoneme':
            item['ph_dur'] = ph_dur
        return item
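
    # Wrap one preprocessed item into a batch of numpy arrays shaped like the
    # ONNX model inputs; the commented torch calls show the original code.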
    def input_to_batch(self, item):
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = np.array(item['ph_token'], np.int64)[None, :]
        # txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = np.array([txt_tokens.shape[1]], np.int64)
        # txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        # The original np.zeros(item['spk_id'], ...) produced an empty array for
        # spk_id 0; build a [1, 1] array holding the speaker id instead.
        spk_ids = np.array([item['spk_id']], np.int64)[None, :]
        # spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device)

        pitch_midi = np.array(item['pitch_midi'], np.int64)[None, :hparams['max_frames']]
        # pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
        midi_dur = np.array(item['midi_dur'], np.float32)[None, :hparams['max_frames']]
        # midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
        is_slur = np.array(item['is_slur'], np.int64)[None, :hparams['max_frames']]
        # is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)
        mel2ph = None

        # if item['ph_dur'] is not None:
        #     ph_acc = np.around(np.add.accumulate(24000 * item['ph_dur'] / 128)).astype('int')
        #     ph_dur = np.diff(ph_acc, prepend=0)
        #     ph_dur = np.array(ph_dur, np.int64)[None, :hparams['max_frames']]
        #     lr = LengthRegulator()
        #     mel2ph = lr(ph_dur, txt_tokens == 0).detach()

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
            'pitch_midi': pitch_midi,
            'midi_dur': midi_dur,
            'is_slur': is_slur,
            'mel2ph': mel2ph
        }
        return batch
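
    # Acoustic pipeline: encoder (singer_fs) -> K_step reverse diffusion steps
    # (singer_denoise) -> optional pitch extractor -> vocoder.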
    def forward_model(self, inp):
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')
        mel2ph = sample['mel2ph']

        decoder_inp = self.model2.run(
            None,
            {
                "txt_tokens": txt_tokens,
                # "spk_id": spk_id,
                "pitch_midi": sample['pitch_midi'],
                "midi_dur": sample['midi_dur'],
                "is_slur": sample['is_slur'],
                # "mel2ph": np.array([0, 0]).astype(np.int64)
            }
        )
        cond = np.transpose(decoder_inp[0], (0, 2, 1))
        # cond = torch.from_numpy(decoder_inp[0]).transpose(1, 2)

        t = hparams['K_step']
        print('===> Gaussian noise start.')
        shape = (cond.shape[0], 1,
                 hparams['audio_num_mel_bins'], cond.shape[2])
        # x = torch.randn(shape)
        # x = torch.zeros(shape, device=device)
        x = np.random.randn(*shape).astype(np.float32)

        for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
            res2 = self.model3.run(
                None,
                {
                    "x": x,
                    "t": np.array([i]).astype(np.int64),
                    "cond": cond,
                }
            )
            x = res2[0]

        # x = x[:, 0].transpose(1, 2)
        x = np.transpose(x[:, 0], (0, 2, 1))

        if mel2ph is not None:  # for singing
            # numpy arrays have no .float(); use astype for the torch-style mask.
            mel_out = denorm_spec(x) * (mel2ph > 0).astype(np.float32)[:, :, None]
        else:
            mel_out = denorm_spec(x)

        # mel_out = output['mel_out']  # [B, T, 80]

        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            pe2_res = self.pe2.run(
                None,
                {
                    'mel_input': mel_out
                }
            )
            # pe predicts f0 from the predicted mel
            f0_pred = pe2_res[1]
        else:
            # f0_pred = output['f0_denorm']
            f0_pred = None

        # Run Vocoder
        wav_out = self.run_vocoder(mel_out, f0=f0_pred)
        # wav_out = wav_out.cpu().numpy()
        return wav_out[0]
    def postprocess_output(self, output):
        return output

    def infer_once(self, inp):
        inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output
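

# Example word-level input: 'notes' separates per-character note windows with
# |; two notes in one window (e.g. 'A#4/Bb4 F#4/Gb4') mark a slur, and
# 'notes_duration' gives the matching per-note durations in seconds.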
if __name__ == '__main__':
    c = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
        'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
        'input_type': 'word'
    }  # user input: Chinese characters

    target = "./infer_out/onnx_test_singer_res.wav"

    set_hparams(print_hparams=False)

    spec_min = np.array(hparams['spec_min'], np.float32)[None, None, :hparams['keep_bins']]
    spec_max = np.array(hparams['spec_max'], np.float32)[None, None, :hparams['keep_bins']]

    infer_ins = TestAllInfer(hparams)

    out = infer_ins.infer_once(c)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    print(f'| save audio: {target}')
    save_wav(out, target, hparams['audio_sample_rate'])

    print("OK")
