Skip to content

Commit 3810490

Browse files
committed
DPO tweaks and notebook changes
1 parent 19b771c commit 3810490

File tree

2 files changed

+77
-88
lines changed

2 files changed

+77
-88
lines changed

scripts/t5tts/dpo/create_text_contextpairs.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,27 @@ def main():
3333
parser.add_argument("--output_manifest", type=str)
3434
args = parser.parse_args()
3535

36-
with open(args.challenging_texts, 'r') as f:
37-
challenging_texts = f.readlines()
38-
challenging_texts = [text.strip() for text in challenging_texts if text.strip() != '']
39-
40-
with open(args.regular_texts_for_audiocontext, 'r') as f:
41-
regular_texts_for_audiocontext = f.readlines()
42-
regular_texts_for_audiocontext = [text.strip() for text in regular_texts_for_audiocontext if text.strip() != '']
43-
44-
with open(args.regular_texts_for_textcontext, 'r') as f:
45-
regular_texts_for_textcontext = f.readlines()
46-
regular_texts_for_textcontext = [text.strip() for text in regular_texts_for_textcontext if text.strip() != '']
36+
if args.challenging_texts is not None:
37+
with open(args.challenging_texts, 'r') as f:
38+
challenging_texts = f.readlines()
39+
challenging_texts = [text.strip() for text in challenging_texts if text.strip() != '']
40+
else:
41+
challenging_texts = None
42+
43+
if args.regular_texts_for_audiocontext is not None:
44+
with open(args.regular_texts_for_audiocontext, 'r') as f:
45+
regular_texts_for_audiocontext = f.readlines()
46+
regular_texts_for_audiocontext = [text.strip() for text in regular_texts_for_audiocontext if text.strip() != '']
47+
else:
48+
regular_texts_for_audiocontext = []
4749

50+
if args.regular_texts_for_textcontext is not None:
51+
with open(args.regular_texts_for_textcontext, 'r') as f:
52+
regular_texts_for_textcontext = f.readlines()
53+
regular_texts_for_textcontext = [text.strip() for text in regular_texts_for_textcontext if text.strip() != '']
54+
else:
55+
regular_texts_for_textcontext = None
56+
4857
with open(args.audio_contexts, 'r') as f:
4958
audio_contexts = f.readlines()
5059
audio_contexts = [json.loads(context.strip()) for context in audio_contexts if context.strip() != '']
@@ -66,18 +75,19 @@ def main():
6675
text_context = random.choice(text_contexts)
6776
record = create_text_context_record(challenging_text, text_context, dummy_audio_filepath, 'challenging', dummy_target_audio_codes_path)
6877
all_records.append(record)
69-
70-
for regular_text in regular_texts_for_audiocontext:
71-
for _ in range(args.n_audio_contexts_per_regular_text):
72-
audio_context = random.choice(audio_contexts)
73-
record = create_audio_context_record(regular_text, audio_context, 'regular')
74-
all_records.append(record)
75-
76-
for regular_text in regular_texts_for_textcontext:
77-
for _ in range(args.n_text_contexts_per_regular_text):
78-
text_context = random.choice(text_contexts)
79-
record = create_text_context_record(regular_text, text_context, dummy_audio_filepath, 'regular', dummy_target_audio_codes_path)
80-
all_records.append(record)
78+
79+
if regular_texts_for_audiocontext is not None:
80+
for regular_text in regular_texts_for_audiocontext:
81+
for _ in range(args.n_audio_contexts_per_regular_text):
82+
audio_context = random.choice(audio_contexts)
83+
record = create_audio_context_record(regular_text, audio_context, 'regular')
84+
all_records.append(record)
85+
if regular_texts_for_textcontext is not None:
86+
for regular_text in regular_texts_for_textcontext:
87+
for _ in range(args.n_text_contexts_per_regular_text):
88+
text_context = random.choice(text_contexts)
89+
record = create_text_context_record(regular_text, text_context, dummy_audio_filepath, 'regular', dummy_target_audio_codes_path)
90+
all_records.append(record)
8191

8292
random.shuffle(all_records)
8393
repeated_records = []

t5tts_inference.ipynb

Lines changed: 44 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,10 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": null,
66
"id": "466ccdc5",
77
"metadata": {},
8-
"outputs": [
9-
{
10-
"name": "stderr",
11-
"output_type": "stream",
12-
"text": [
13-
"[NeMo W 2025-02-03 20:06:44 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14-
" from .autonotebook import tqdm as notebook_tqdm\n",
15-
" \n"
16-
]
17-
}
18-
],
8+
"outputs": [],
199
"source": [
2010
"from nemo.collections.tts.models import T5TTS_Model\n",
2111
"from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample\n",
@@ -48,8 +38,8 @@
4838
"# Checkpoint and Hparams Paths\n",
4939
"# hparams_file = \"/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/T5TTS/0/hparams.yaml\"\n",
5040
"# checkpoint_file = \"/datap/misc/Experiments/SimpleT5Explore/LocalTraining_LRH/T5TTS/0/checkpoints/test.ckpt\"\n",
41+
"checkpoint_file = \"/datap/misc/continuouscheckpoints/chal/2502_finetune_challenging_LR1e-5_T5TTS--val_loss=5.2330-epoch=0-last.ckpt\" #T5TTS--val_loss=5.8671-epoch=1-last.ckpt\"\n",
5142
"hparams_file = \"/home/rfejgin/release_2502/hparams__final_xform__yt_weight0.25_plus_18k_single_stage_enc3_fixes_phoneme_only.yaml\"\n",
52-
"codecmodel_path = \"/home/rfejgin/release_2502/dpo_fine_tuning_beta0.1__final_xform_enc3_T5TTS--val_loss_0.3899-epoch_24.ckpt\"\n",
5343
"\n",
5444
"# Temp out dir for saving audios\n",
5545
"out_dir = \"/datap/misc/t5tts_inference_notebook_samples\"\n",
@@ -67,48 +57,18 @@
6757
},
6858
{
6959
"cell_type": "code",
70-
"execution_count": 3,
60+
"execution_count": null,
7161
"id": "87bf66f9",
7262
"metadata": {},
73-
"outputs": [
74-
{
75-
"name": "stderr",
76-
"output_type": "stream",
77-
"text": [
78-
"[NeMo W 2025-02-03 20:06:51 experimental:26] `<class 'nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p'>` is experimental and not ready for production yet. Use at your own risk.\n",
79-
"[NeMo W 2025-02-03 20:06:52 i18n_ipa:124] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.\n",
80-
"[NeMo W 2025-02-03 20:06:52 experimental:26] `<class 'nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer'>` is experimental and not ready for production yet. Use at your own risk.\n",
81-
"[NeMo W 2025-02-03 20:06:52 experimental:26] `<class 'nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p'>` is experimental and not ready for production yet. Use at your own risk.\n",
82-
"[NeMo W 2025-02-03 20:06:53 i18n_ipa:124] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.\n",
83-
"[NeMo W 2025-02-03 20:06:53 experimental:26] `<class 'nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer'>` is experimental and not ready for production yet. Use at your own risk.\n",
84-
"[NeMo W 2025-02-03 20:06:53 experimental:26] `<class 'nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p'>` is experimental and not ready for production yet. Use at your own risk.\n",
85-
"[NeMo W 2025-02-03 20:06:54 i18n_ipa:124] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.\n",
86-
"[NeMo W 2025-02-03 20:06:54 experimental:26] `<class 'nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer'>` is experimental and not ready for production yet. Use at your own risk.\n",
87-
"[NeMo W 2025-02-03 20:06:54 zh_cn_pinyin:100] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.\n",
88-
"You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
89-
]
90-
},
91-
{
92-
"ename": "TypeError",
93-
"evalue": "Transformer.__init__() got an unexpected keyword argument 'pos_emb'",
94-
"output_type": "error",
95-
"traceback": [
96-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
97-
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
98-
"Cell \u001b[0;32mIn[3], line 28\u001b[0m\n\u001b[1;32m 24\u001b[0m model_cfg\u001b[38;5;241m.\u001b[39mtrain_ds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 25\u001b[0m model_cfg\u001b[38;5;241m.\u001b[39mvalidation_ds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mT5TTS_Model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_cfg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoading weights from checkpoint\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 30\u001b[0m ckpt \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(checkpoint_file)\n",
99-
"File \u001b[0;32m/home/rfejgin/NeMo/nemo/collections/tts/models/t5tts.py:153\u001b[0m, in \u001b[0;36mT5TTS_Model.__init__\u001b[0;34m(self, cfg, trainer)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdecoder_pretrain_synthesizer\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 151\u001b[0m \u001b[38;5;66;03m# Decoder pretrain synthesizer doesn't have transcript encoder/text embeddings\u001b[39;00m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtext_embedding \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mEmbedding(num_tokens, cfg\u001b[38;5;241m.\u001b[39membedding_dim)\n\u001b[0;32m--> 153\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mt5_encoder \u001b[38;5;241m=\u001b[39m \u001b[43mt5tts_transformer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mt5_encoder\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mt5_decoder \u001b[38;5;241m=\u001b[39m t5tts_transformer\u001b[38;5;241m.\u001b[39mTransformer(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mdict\u001b[39m(cfg\u001b[38;5;241m.\u001b[39mt5_decoder))\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfinal_proj \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(cfg\u001b[38;5;241m.\u001b[39mt5_decoder\u001b[38;5;241m.\u001b[39md_model, cfg\u001b[38;5;241m.\u001b[39mnum_audio_codebooks \u001b[38;5;241m*\u001b[39m cfg\u001b[38;5;241m.\u001b[39mnum_audio_tokens_per_codebook)\n",
100-
"\u001b[0;31mTypeError\u001b[0m: Transformer.__init__() got an unexpected keyword argument 'pos_emb'"
101-
]
102-
}
103-
],
63+
"outputs": [],
10464
"source": [
10565
"#hparams_file = \"yt_weight_0.25_plus18k__dim1536__enc3_fixes_hparams.yaml\"\n",
10666
"#checkpoint_file = \"yt_weight_0.25_plus18k__dim1536__enc3_fixes_val_loss_5.1870_epoch_25.ckpt\"\n",
10767
"\n",
10868
"#hparams_file = \"/data/t5_new_cp/configs/unnormalizedLalign005_singleencoder_kernel3_hparams.yaml\"\n",
10969
"#checkpoint_file = \"/data/t5_new_cp/checkpoints/unnormalizedLalign005_singleencoder_kernel3_epoch_20.ckpt\" #\"/datap/misc/continuouscheckpoints/edresson_epoch21.ckpt\"\n",
110-
"hparams_file = \"/datap/misc/continuouscheckpoints/yt_weight0.25_plus_18k_single_stage_decoder_context_kernel1_fixes_hparams.yaml\"\n",
111-
"checkpoint_file =\"/datap/misc/continuouscheckpoints/yt_weight0.25_plus_18k_single_stage_decoder_context_kernel1_fixes_epoch_61.ckpt\" \n",
70+
"#hparams_file = \"/datap/misc/continuouscheckpoints/yt_weight0.25_plus_18k_single_stage_decoder_context_kernel1_fixes_hparams.yaml\"\n",
71+
"#checkpoint_file =\"/datap/misc/continuouscheckpoints/yt_weight0.25_plus_18k_single_stage_decoder_context_kernel1_fixes_epoch_61.ckpt\" \n",
11272
"#hparams_file = \"/datap/misc/continuouscheckpoints/decoder_context_large_hparams.yaml\"\n",
11373
"#checkpoint_file =\"/datap/misc/continuouscheckpoints/decoder_context_large_epoch_14.ckpt\" \n",
11474
"\n",
@@ -219,24 +179,41 @@
219179
},
220180
{
221181
"cell_type": "code",
222-
"execution_count": 4,
182+
"execution_count": 12,
223183
"id": "74683d11",
224184
"metadata": {},
225185
"outputs": [],
226186
"source": [
227187
"usg_cfg = True\n",
228-
"cfg_scale = 1.8\n",
188+
"cfg_scale = 2.5\n",
229189
"audio_dir = \"/home/rfejgin/kb-snippets\"\n",
230190
"#audio_dir = \"/data/NV-RESTRICTED/JHSD/22khz\"\n",
231-
"texts = [\"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
232-
" \"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
233-
" \"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
234-
" \"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
235-
" \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
236-
" \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
237-
" \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
238-
" \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
239-
" \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\"]\n",
191+
"texts = [\"Let me confirm S D S D two two two two, one, two, four, four, h t t p, four, five, six, seven, eight. Is that correct?\",\n",
192+
" \"Let me confirm S D S D two two two two, one, two, four, four, h t t p, four, five, six, seven, eight. Is that correct?\",\n",
193+
" \"Let me confirm S D S D two two two two, one, two, four, four, h t t p, four, five, six, seven, eight. Is that correct?\",\n",
194+
" \"Let me confirm S D S D two two two two, one, two, four, four, h t t p, four, five, six, seven, eight. Is that correct?\",\n",
195+
" \"Let me confirm S D S D two two two two, one, two, four, four, h t t p, four, five, six, seven, eight. Is that correct?\",\n",
196+
" \"hello\",\n",
197+
" \"hello\",\n",
198+
" \"hello\",\n",
199+
" \"hello\",\n",
200+
" \"hello\",\n",
201+
"]\n",
202+
" #\"hi\", \"hi\",\"hi\",\"hi\",\"hi\",\n",
203+
" # \"Let me confirm that number: two, one, two, four, four, four, five, six, seven, eight. Is that correct?\",\n",
204+
" # \"Let me confirm that number: two, one, two, four, four, four, five, six, seven, eight. Is that correct?\",\n",
205+
" # \"Let me confirm that number: two, one, two, four, four, four, five, six, seven, eight. Is that correct?\",\n",
206+
" # \"Let me confirm that number: two, one, two, four, four, four, five, six, seven, eight. Is that correct?\",\n",
207+
" # \"Let me confirm that number: two, one, two, four, four, four, five, six, seven, eight. Is that correct?\"]\n",
208+
"# texts = [\"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
209+
"# \"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
210+
"# \"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
211+
"# \"Our GPUs aren’t just processors; they are engines for discovery, powering breakthroughs in everything from self-driving cars to disease research.\",\n",
212+
"# \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
213+
"# \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
214+
"# \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
215+
"# \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\",\n",
216+
"# \"When comparing Heavenly and Northstar ski resorts in Lake Tahoe, each offers unique advantages.\"]\n",
240217
"# texts = [\"NVIDIA's Riva is a powerful speech AI toolkit that offers state-of-the-art ASR and TTS capabilities.\",\n",
241218
"# 'The platform supports multiple languages and provides enterprise-grade speech technology through GPU-accelerated SDKs and APIs.',\n",
242219
"# 'What makes Riva unique is its ability to be customized for specific use cases while maintaining high performance and accuracy.',\n",
@@ -264,8 +241,9 @@
264241
" \"duration\": 4.89,\n",
265242
" \"text\": text,\n",
266243
" \"speaker\": \"dummy\",\n",
267-
" \"context_audio_filepath\": \"roy2_22050.wav\",#\"AMP20_KEYNOTE-VOOnly-44khz-16bit-mono_6.wav\",#\"adi-snippet1.wav\",\n",
268-
" \"context_audio_duration\": 4.89\n",
244+
" \"context_text\": \"Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_WIZWIKI |\"\n",
245+
" #\"context_audio_filepath\": \"roy2_22050.wav\",#\"AMP20_KEYNOTE-VOOnly-44khz-16bit-mono_6.wav\",#\"adi-snippet1.wav\",\n",
246+
" #\"context_audio_duration\": 4.89\n",
269247
" }\n",
270248
" entries.append(entry)\n",
271249
"data_samples = [DatasetSample(\n",
@@ -291,7 +269,7 @@
291269
},
292270
{
293271
"cell_type": "code",
294-
"execution_count": 5,
272+
"execution_count": 13,
295273
"id": "b7374d3f",
296274
"metadata": {},
297275
"outputs": [],
@@ -456,18 +434,19 @@
456434
"outputs": [],
457435
"source": [
458436
"print(f\"Checkpoint: {checkpoint_file}\")\n",
459-
"context_filepath = os.path.join(audio_dir, entry['context_audio_filepath'])\n",
460-
"display(Audio(context_filepath))\n"
437+
"if 'context_audio_filepath' in entry:\n",
438+
" context_filepath = os.path.join(audio_dir, entry['context_audio_filepath'])\n",
439+
" display(Audio(context_filepath))\n"
461440
]
462441
},
463442
{
464443
"cell_type": "code",
465-
"execution_count": null,
444+
"execution_count": 16,
466445
"id": "0a72ccec",
467446
"metadata": {},
468447
"outputs": [],
469448
"source": [
470-
"entry['context_audio_filepath']"
449+
"if 'context_audio_filepath' in entry: entry['context_audio_filepath']"
471450
]
472451
},
473452
{

0 commit comments

Comments
 (0)