Skip to content

Commit eb90621

Browse files
authored
Enable usage of Min_P Sampler, modify other sampler settings (#155)
This commit enables use of the min_p sampler and adds sliders to the Gradio app for configuring top_p, min_p, and repetition_penalty. Defaults changed to min_p=0.05, repetition_penalty=1.2, top_p=1.0. Note: min_p=0.00 disables min_p sampling; top_p=1.00 disables top_p sampling.
1 parent a65a3f8 commit eb90621

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

gradio_tts_app.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def load_model():
2121
return model
2222

2323

24-
def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw):
24+
def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty):
2525
if model is None:
2626
model = ChatterboxTTS.from_pretrained(DEVICE)
2727

@@ -34,6 +34,9 @@ def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num
3434
exaggeration=exaggeration,
3535
temperature=temperature,
3636
cfg_weight=cfgw,
37+
min_p=min_p,
38+
top_p=top_p,
39+
repetition_penalty=repetition_penalty,
3740
)
3841
return (model.sr, wav.squeeze(0).numpy())
3942

@@ -55,6 +58,9 @@ def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num
5558
with gr.Accordion("More options", open=False):
5659
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
5760
temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
61+
min_p = gr.Slider(0.00, 1.00, step=0.01, label="min_p || Newer Sampler. Recommend 0.02 > 0.1. Handles Higher Temperatures better. 0.00 Disables", value=0.05)
62+
top_p = gr.Slider(0.00, 1.00, step=0.01, label="top_p || Original Sampler. 1.0 Disables(recommended). Original 0.8", value=1.00)
63+
repetition_penalty = gr.Slider(1.00, 2.00, step=0.1, label="repetition_penalty", value=1.2)
5864

5965
run_btn = gr.Button("Generate", variant="primary")
6066

@@ -73,6 +79,9 @@ def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num
7379
temp,
7480
seed_num,
7581
cfg_weight,
82+
min_p,
83+
top_p,
84+
repetition_penalty,
7685
],
7786
outputs=audio_output,
7887
)

src/chatterbox/models/t3/t3.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.nn.functional as F
99
from torch import nn, Tensor
1010
from transformers import LlamaModel, LlamaConfig
11-
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
11+
from transformers.generation.logits_process import MinPLogitsWarper, RepetitionPenaltyLogitsProcessor, TopPLogitsWarper
1212

1313
from .modules.learned_pos_emb import LearnedPositionEmbeddings
1414

@@ -217,9 +217,10 @@ def inference(
217217
stop_on_eos=True,
218218
do_sample=True,
219219
temperature=0.8,
220-
top_p=0.8,
220+
min_p=0.05,
221+
top_p=1.00,
221222
length_penalty=1.0,
222-
repetition_penalty=2.0,
223+
repetition_penalty=1.2,
223224
cfg_weight=0,
224225
):
225226
"""
@@ -271,7 +272,7 @@ def inference(
271272
# max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
272273
# num_return_sequences=num_return_sequences,
273274
# temperature=temperature,
274-
# top_p=top_p,
275+
# min_p=min_p,
275276
# length_penalty=length_penalty,
276277
# repetition_penalty=repetition_penalty,
277278
# do_sample=do_sample,
@@ -298,8 +299,9 @@ def inference(
298299
predicted = [] # To store the predicted tokens
299300

300301
# Instantiate the logits processors.
302+
min_p_warper = MinPLogitsWarper(min_p=min_p)
301303
top_p_warper = TopPLogitsWarper(top_p=top_p)
302-
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
304+
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
303305

304306
# ---- Initial Forward Pass (no kv_cache yet) ----
305307
output = self.patched_model(
@@ -331,6 +333,7 @@ def inference(
331333

332334
# Apply repetition penalty and top‑p filtering.
333335
logits = repetition_penalty_processor(generated_ids, logits)
336+
logits = min_p_warper(None, logits)
334337
logits = top_p_warper(None, logits)
335338

336339
# Convert logits to probabilities and sample the next token.

src/chatterbox/tts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
208208
def generate(
209209
self,
210210
text,
211+
repetition_penalty=1.2,
212+
min_p=0.05,
213+
top_p=1.0,
211214
audio_prompt_path=None,
212215
exaggeration=0.5,
213216
cfg_weight=0.5,
@@ -246,6 +249,9 @@ def generate(
246249
max_new_tokens=1000, # TODO: use the value in config
247250
temperature=temperature,
248251
cfg_weight=cfg_weight,
252+
repetition_penalty=repetition_penalty,
253+
min_p=min_p,
254+
top_p=top_p,
249255
)
250256
# Extract only the conditional batch.
251257
speech_tokens = speech_tokens[0]

0 commit comments

Comments
 (0)