113 changes: 113 additions & 0 deletions INFERENCE.md
@@ -160,6 +160,119 @@ for (sampling_rate, audio_chunk) in generate(text, description, chunk_size_in_s)
    print(audio_chunk.shape)
```

### Async Streamer

If you want to overlap computation across concurrent requests, you can use the asynchronous streamer; see [AsyncParlerTTSStreamer](https://github.com/huggingface/parler-tts/blob/main/parler_tts/streamer.py).

Here's how to use it.

```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration, AsyncParlerTTSStreamer
from transformers import AutoTokenizer
from threading import Thread
import asyncio

torch_device = "cuda:0"  # Use "mps" for Mac
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
).to(torch_device, dtype=torch_dtype)

sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate
play_steps_in_s = 0.5

async def main(text, request_id):
    play_steps = int(frame_rate * play_steps_in_s)
    streamer = AsyncParlerTTSStreamer(model, device=torch_device, play_steps=play_steps)
    description = "A female speaker with a slightly low-pitched voice"
    inputs = tokenizer(description, return_tensors="pt").to(torch_device)
    prompt = tokenizer(text, return_tensors="pt").to(torch_device)

    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        prompt_input_ids=prompt.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_attention_mask=prompt.attention_mask,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        min_new_tokens=10,
        decode=False,  # skip the final decode to save computation
    )

    # run generation in a background thread so the event loop stays free
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    async for new_audio in streamer:
        if new_audio.shape[0] == 0:
            break

        print(f"Request ID: {request_id}, Sample of length: {round(new_audio.shape[0] / sampling_rate, 4)} seconds")

    thread.join()

async def run_requests():
    prompts = [
        "that can generate high-quality, natural sounding speech with features that can be controlled using a simple text prompt",
        "which aims to provide the community with TTS training resources and dataset pre-processing code.",
        "include the term 'very clear audio' to generate the highest quality audio, and 'very noisy audio' for high levels of background noise",
    ]
    tasks = [asyncio.create_task(main(p, no)) for no, p in enumerate(prompts)]
    await asyncio.gather(*tasks)

asyncio.run(run_requests())
```

Output:
```
Request ID: 2, Sample of length: 0.329 seconds
Request ID: 0, Sample of length: 0.329 seconds
Request ID: 1, Sample of length: 0.329 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 1, Sample of length: 0.4992 seconds
Request ID: 0, Sample of length: 0.4489 seconds
Request ID: 1, Sample of length: 0.2283 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4992 seconds
Request ID: 2, Sample of length: 0.4489 seconds
```
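
Each chunk yielded by the streamer is a waveform at `sampling_rate`, so you can accumulate the chunks per request and write a file once the stream ends. Below is a minimal sketch, assuming the async streamer yields NumPy-compatible arrays (as the synchronous `ParlerTTSStreamer` does) and that `soundfile` is installed; `collect_and_save` is an illustrative helper, not part of the library:

```py
import numpy as np
import soundfile as sf

async def collect_and_save(streamer, request_id, sampling_rate):
    """Accumulate streamed chunks and write them out as one wav file."""
    chunks = []
    async for new_audio in streamer:
        if new_audio.shape[0] == 0:  # an empty chunk signals the end of generation
            break
        chunks.append(np.asarray(new_audio))
    if chunks:
        sf.write(f"request_{request_id}.wav", np.concatenate(chunks), sampling_rate)
```

You would call this in place of the `async for` loop inside `main` above, after starting the generation thread.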

## Batch generation

Batching combines the operations for multiple samples so that the total time spent generating all of them is lower than generating each sample one by one; a short sketch of the pattern follows.
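
As a quick illustration, here is a minimal sketch reusing `model`, `tokenizer`, and `torch_device` from the async example above; the left padding and the `audios_length` field are assumptions based on how `generate` returns batched audio with `return_dict_in_generate=True` (see the diff to `generate` below), not code taken from this section:

```py
descriptions = [
    "A female speaker with a slightly low-pitched voice.",
    "A male speaker with a deep, calm voice.",
]
texts = [
    "Hey, how are you doing today?",
    "The quick brown fox jumps over the lazy dog.",
]

# Pad every sample to a shared length so the batch runs in a single generate() call.
tokenizer.padding_side = "left"
inputs = tokenizer(descriptions, return_tensors="pt", padding=True).to(torch_device)
prompts = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

generation = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    prompt_input_ids=prompts.input_ids,
    prompt_attention_mask=prompts.attention_mask,
    do_sample=True,
    return_dict_in_generate=True,
)

# sequences holds the padded waveforms; audios_length gives each sample's true length.
for i, length in enumerate(generation.audios_length):
    print(i, generation.sequences[i, :length].shape)
```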
2 changes: 1 addition & 1 deletion parler_tts/__init__.py
@@ -12,7 +12,7 @@
     build_delay_pattern_mask,
 )

-from .streamer import ParlerTTSStreamer
+from .streamer import ParlerTTSStreamer, AsyncParlerTTSStreamer

from importlib.metadata import version
from packaging.version import Version
138 changes: 72 additions & 66 deletions parler_tts/modeling_parler_tts.py
@@ -2276,7 +2276,6 @@ def generate(
             output_ids = outputs.sequences
         else:
             output_ids = outputs
-
         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

@@ -2298,6 +2297,7 @@
         return output_ids


+
 @add_start_docstrings(
     "The composite Parler-TTS model with a text encoder, audio encoder and ParlerTTS decoder, "
     "for music generation tasks with one or both of text and audio prompts.",
@@ -3327,6 +3327,7 @@ def generate(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         synced_gpus: Optional[bool] = None,
         streamer: Optional["BaseStreamer"] = None,
+        decode: Optional[bool] = True,
         **kwargs,
     ):
         """
@@ -3370,6 +3371,8 @@
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            decode (`bool`, *optional*, defaults to `True`):
+                Whether to decode the generated audio codes into waveforms with the DAC audio codec. Set to `False` to skip this step and save computation.
             kwargs (`Dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -3577,80 +3580,83 @@
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
)

if generation_config.return_dict_in_generate:
output_ids = outputs.sequences
else:
output_ids = outputs
if decode:
if generation_config.return_dict_in_generate:
output_ids = outputs.sequences
else:
output_ids = outputs

# Apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# Apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)
# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)

mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)

# append the frame dimension back to the audio codes
output_ids = output_ids[None, ...]
# append the frame dimension back to the audio codes
output_ids = output_ids[None, ...]

audio_decode_kwargs = {}
if self.use_audio_scales:
audio_scales = model_kwargs.get("audio_scales")
if audio_scales is None:
audio_scales = [None] * batch_size
audio_decode_kwargs["audio_scales"] = audio_scales
audio_decode_kwargs = {}
if self.use_audio_scales:
audio_scales = model_kwargs.get("audio_scales")
if audio_scales is None:
audio_scales = [None] * batch_size
audio_decode_kwargs["audio_scales"] = audio_scales


if not self.use_4dim_audio_codes:
# remove chunk dim
output_ids = output_ids.squeeze(0)


decode_sequentially = (
generation_config.bos_token_id in output_ids
or generation_config.pad_token_id in output_ids
or generation_config.eos_token_id in output_ids
)
if not decode_sequentially:
output_values = self.audio_encoder.decode(
audio_codes=output_ids,
**audio_decode_kwargs,
).audio_values.squeeze(1)
output_lengths = [audio.shape[0] for audio in output_values]
else:
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id] if self.use_4dim_audio_codes else output_ids[sample_id]
sample_mask = (sample >= self.audio_encoder.config.codebook_size)
sample_mask = (sample_mask.sum(dim=(0, 1)) == 0) if self.use_4dim_audio_codes else (sample_mask.sum(dim=0) == 0)
single_audio_decode_kwargs = {}
if self.use_audio_scales:
single_audio_decode_kwargs["audio_scales"] = [audio_decode_kwargs["audio_scales"][sample_id]]
if sample_mask.sum() > 0:
sample = sample[:, :, sample_mask] if self.use_4dim_audio_codes else sample[:, sample_mask]
sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **single_audio_decode_kwargs).audio_values
sample = sample if sample.ndim == 3 else sample.unsqueeze(0)
output_values.append(sample.transpose(0, 2))
else:
output_values.append(torch.zeros((1, 1, 1)).to(self.device))
output_lengths = [audio.shape[0] for audio in output_values]
output_values = (
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
.squeeze(-1)
.squeeze(-1)
if not self.use_4dim_audio_codes:
# remove chunk dim
output_ids = output_ids.squeeze(0)


decode_sequentially = (
generation_config.bos_token_id in output_ids
or generation_config.pad_token_id in output_ids
or generation_config.eos_token_id in output_ids
)
if generation_config.return_dict_in_generate:
outputs["audios_length"] = output_lengths
outputs.sequences = output_values
return outputs
if not decode_sequentially:
output_values = self.audio_encoder.decode(
audio_codes=output_ids,
**audio_decode_kwargs,
).audio_values.squeeze(1)
output_lengths = [audio.shape[0] for audio in output_values]
else:
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id] if self.use_4dim_audio_codes else output_ids[sample_id]
sample_mask = (sample >= self.audio_encoder.config.codebook_size)
sample_mask = (sample_mask.sum(dim=(0, 1)) == 0) if self.use_4dim_audio_codes else (sample_mask.sum(dim=0) == 0)
single_audio_decode_kwargs = {}
if self.use_audio_scales:
single_audio_decode_kwargs["audio_scales"] = [audio_decode_kwargs["audio_scales"][sample_id]]
if sample_mask.sum() > 0:
sample = sample[:, :, sample_mask] if self.use_4dim_audio_codes else sample[:, sample_mask]
sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **single_audio_decode_kwargs).audio_values
sample = sample if sample.ndim == 3 else sample.unsqueeze(0)
output_values.append(sample.transpose(0, 2))
else:
output_values.append(torch.zeros((1, 1, 1)).to(self.device))
output_lengths = [audio.shape[0] for audio in output_values]
output_values = (
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
.squeeze(-1)
.squeeze(-1)
)
if generation_config.return_dict_in_generate:
outputs["audios_length"] = output_lengths
outputs.sequences = output_values
return outputs
else:
return output_values
else:
return output_values
return outputs

def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""