This repository contains a PyTorch implementation of the Voxtral Codec component used to convert 24 kHz mono speech waveforms into discrete codes for TTS training.
The implementation is aligned with the research notes in this repository and includes:
- A causal convolution-transformer autoencoder
- A 292-dimensional latent split into 256 semantic and 36 acoustic dimensions
- Semantic vector quantization with an 8192-entry codebook
- Acoustic finite scalar quantization with 21 levels per dimension
- A multi-resolution STFT discriminator
- Whisper-based ASR distillation support
Pointwise summary:
- Input audio is expected at 24 kHz, mono.
- The waveform is split into non-overlapping patches of 240 samples.
- Patch frames are projected into a hidden representation with a causal kernel-7 convolution.
- Four encoder stages reduce the frame rate from 100 Hz to 12.5 Hz.
- The latent is split into:
  - 256 semantic dimensions, quantized with VQ
  - 36 acoustic dimensions, quantized with FSQ
- Each 12.5 Hz frame produces 37 tokens total:
  - 1 semantic token
  - 36 acoustic tokens
- The decoder reconstructs waveform patches from the quantized latent.
- Training uses feature matching, ASR distillation, L1 reconstruction, STFT magnitude reconstruction, and VQ commitment loss.
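
The patching and token-count arithmetic above can be checked with a short sketch. It uses plain Python lists rather than PyTorch tensors to stay self-contained. `patchify` and `tokens_for_duration` are hypothetical helper names, not functions from this repository, and dropping an incomplete trailing patch is an assumption (the actual code may pad instead).

```python
PATCH_SIZE = 240           # samples per patch (from the summary above)
FRAME_RATE_HZ = 12.5       # latent frame rate after the encoder
TOKENS_PER_FRAME = 1 + 36  # 1 semantic token + 36 acoustic tokens


def patchify(waveform, patch_size=PATCH_SIZE):
    """Split a 1-D sequence of samples into non-overlapping patches.

    Assumption: an incomplete trailing patch is dropped.
    """
    n_patches = len(waveform) // patch_size
    return [waveform[i * patch_size:(i + 1) * patch_size] for i in range(n_patches)]


def tokens_for_duration(seconds, frame_rate_hz=FRAME_RATE_HZ,
                        tokens_per_frame=TOKENS_PER_FRAME):
    """Number of discrete tokens for `seconds` of audio (whole frames only)."""
    return int(seconds * frame_rate_hz) * tokens_per_frame


one_second = [0.0] * 24_000            # one second of silence at 24 kHz
patches = patchify(one_second)
print(len(patches), len(patches[0]))   # 100 patches of 240 samples each
print(tokens_for_duration(4))          # 50 frames * 37 tokens = 1850
```

One second of 24 kHz audio yields 100 patches, which matches the 100 Hz frame rate at the encoder input.
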
Configuration summary:
- Sample rate: 24,000 Hz
- Patch size: 240 samples
- Encoder patch projection: kernel size 7
- Encoder blocks: 4
- Encoder transformer layers per stage: 2 -> 2 -> 2 -> 2
- Sliding attention windows: 16 -> 8 -> 4 -> 2
- Encoder CNN strides: 2 -> 2 -> 2 -> 1
- Latent size: 292
- Semantic bottleneck: 256 dims with codebook size 8192
- Acoustic bottleneck: 36 dims with 21 FSQ levels each
- Target frame rate: $24000 / (240 \times 2 \times 2 \times 2) = 12.5$ Hz
- Approximate bitrate: $12.5 \times (\log_2 8192 + 36 \times \log_2 21) \approx 2.14$ kbps
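
The frame-rate and bitrate figures can be reproduced from the listed values. The following is a minimal sketch; the constant and function names are illustrative and are not taken from this repository.

```python
import math

SAMPLE_RATE = 24_000
PATCH_SIZE = 240
ENCODER_STRIDES = (2, 2, 2, 1)
CODEBOOK_SIZE = 8192
ACOUSTIC_DIMS = 36
FSQ_LEVELS = 21


def frame_rate_hz():
    """Latent frame rate: sample rate divided by the total downsampling factor."""
    return SAMPLE_RATE / (PATCH_SIZE * math.prod(ENCODER_STRIDES))


def bits_per_frame():
    """Bits carried by one semantic token plus 36 acoustic FSQ tokens."""
    return math.log2(CODEBOOK_SIZE) + ACOUSTIC_DIMS * math.log2(FSQ_LEVELS)


def bitrate_kbps():
    """Bitrate in kbps, assuming every code is equally likely (an upper bound)."""
    return frame_rate_hz() * bits_per_frame() / 1000.0


print(frame_rate_hz())           # 12.5
print(round(bitrate_kbps(), 2))  # 2.14
```
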
- `voxtral_codec/encoder.py`: encoder and causal sliding-window transformer blocks
- `voxtral_codec/decoder.py`: decoder and causal upsampling blocks
- `voxtral_codec/quantizer.py`: semantic VQ and acoustic FSQ
- `voxtral_codec/discriminator.py`: multi-resolution STFT discriminator
- `voxtral_codec/asr_distillation.py`: Whisper-based semantic distillation
- `voxtral_codec/losses.py`: reconstruction, STFT, feature matching, hinge GAN losses
- `voxtral_codec/model.py`: top-level codec model
- `train.py`: real training entrypoint for audio datasets
- `dummy_train.py`: smoke test training loop with synthetic data
- `tests/test_model.py`: unit tests
Install dependencies:
/usr/bin/python3 -m pip install -r requirements.txt
If you also want to run tests locally:
/usr/bin/python3 -m pip install pytest
The quickest way to verify the model wiring is to run the dummy training loop.
/usr/bin/python3 dummy_train.py --steps 2 --device cpu
What this does:
- Builds a tiny Voxtral Codec configuration
- Builds the multi-resolution discriminator
- Runs generator and discriminator updates on random audio tensors
- Prints L1, STFT, feature-matching, VQ, adversarial, and discriminator losses
Run the test suite with:
/usr/bin/python3 -m pytest -q
Current validated result in this workspace:
33 passed
Example command:
/usr/bin/python3 train.py \
--data_dir /path/to/wav24k \
--batch_size 8 \
--segment_sec 4 \
--max_steps 400000 \
--save_dir ./checkpoints
Important arguments:
- `--data_dir`: directory containing `.wav` or `.flac` files
- `--batch_size`: batch size for training
- `--segment_sec`: segment length in seconds
- `--use_asr`: enable Whisper-based ASR distillation
- `--whisper_model`: Hugging Face Whisper model name
- `--save_dir`: checkpoint output directory
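
For orientation, the arguments above could be declared with `argparse` roughly as follows. This is a hypothetical sketch, not the actual parser in `train.py`. The argument types are assumptions, and the defaults are copied from the example command rather than from the repository. `--whisper_model` is left without a default so that no model name is invented.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented train.py arguments."""
    parser = argparse.ArgumentParser(description="Sketch of the train.py CLI")
    parser.add_argument("--data_dir", required=True,
                        help="directory containing .wav or .flac files")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="batch size for training")
    parser.add_argument("--segment_sec", type=float, default=4.0,
                        help="segment length in seconds")
    parser.add_argument("--max_steps", type=int, default=400_000,
                        help="number of training steps")
    parser.add_argument("--use_asr", action="store_true",
                        help="enable Whisper-based ASR distillation")
    parser.add_argument("--whisper_model", default=None,
                        help="Hugging Face Whisper model name")
    parser.add_argument("--save_dir", default="./checkpoints",
                        help="checkpoint output directory")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--data_dir", "/path/to/wav24k", "--use_asr"])
print(args.data_dir, args.batch_size, args.use_asr)
```
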
The generator-side objective implemented here combines the following terms:
- $L_{\text{feature}}$: discriminator feature matching loss
- $L_{\text{ASR}}$: Whisper-based semantic distillation loss
- $L_{\text{L1}}$: waveform reconstruction loss
- $L_{\text{STFT}}$: STFT magnitude reconstruction loss
- $L_{\text{commit}}$: VQ codebook plus commitment loss
- $\gamma_t$: decaying reconstruction weight
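
One way to combine the terms listed above is a weighted sum, sketched below. Everything beyond the term names is an assumption made for illustration: the weights `w_feature`, `w_asr`, and `w_commit`, the exponential form of the decay schedule, its `decay` constant, and the choice to scale both reconstruction terms by the same decaying weight. The weighting actually used by this repository may differ.

```python
def reconstruction_weight(step, gamma0=1.0, decay=0.9999):
    """Hypothetical exponential schedule for the decaying weight gamma_t."""
    return gamma0 * decay ** step


def generator_loss(losses, step, w_feature=1.0, w_asr=1.0, w_commit=1.0):
    """Weighted sum of the generator-side terms.

    `losses` maps the keys "feature", "asr", "l1", "stft", and "commit"
    to scalar loss values. All weights here are illustrative assumptions.
    """
    gamma_t = reconstruction_weight(step)
    return (
        w_feature * losses["feature"]
        + w_asr * losses["asr"]
        + gamma_t * (losses["l1"] + losses["stft"])  # assumed shared decay
        + w_commit * losses["commit"]
    )


example = {"feature": 1.0, "asr": 1.0, "l1": 1.0, "stft": 1.0, "commit": 1.0}
print(generator_loss(example, step=0))       # 5.0 while gamma_t is still 1.0
print(generator_loss(example, step=10_000))  # smaller: reconstruction terms fade
```
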
The discriminator uses a hinge-style objective, while the generator primarily relies on feature matching rather than a large explicit GAN generator loss.
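
For reference, the standard hinge discriminator loss and an L1 feature matching loss can be written as below. This is a plain-Python sketch over lists of floats, assuming the common textbook forms. The functions in `voxtral_codec/losses.py` operate on tensors and may normalize or average differently.

```python
def hinge_discriminator_loss(real_scores, fake_scores):
    """Standard hinge loss: push real scores above +1 and fake scores below -1."""
    real_term = sum(max(0.0, 1.0 - s) for s in real_scores) / len(real_scores)
    fake_term = sum(max(0.0, 1.0 + s) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term


def feature_matching_loss(real_features, fake_features):
    """Mean absolute difference between discriminator features.

    Each argument is a list of layers; each layer is a flat list of floats.
    The result is the per-layer mean L1 distance averaged over layers.
    """
    per_layer = []
    for real, fake in zip(real_features, fake_features):
        per_layer.append(sum(abs(r - f) for r, f in zip(real, fake)) / len(real))
    return sum(per_layer) / len(per_layer)


print(hinge_discriminator_loss([2.0, 0.5], [-2.0, 0.0]))    # 0.75
print(feature_matching_loss([[1.0, 2.0]], [[1.0, 4.0]]))    # 1.0
```

Feature matching gives the generator a dense training signal from the discriminator's intermediate activations, which fits a setup where the explicit adversarial generator term is kept small.
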
The implementation includes the training-time behavior described in the notes:
- Semantic VQ:
  - 50% quantized
  - 50% passed through unquantized
- Acoustic FSQ:
  - 50% quantized
  - 25% dithered with uniform noise of magnitude $1/L$
  - 25% passed through unquantized
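
The acoustic FSQ schedule above can be illustrated on a single scalar. This sketch makes several assumptions that are not confirmed by the repository: inputs are bounded to $[-1, 1]$, the levels are uniformly spaced on that interval, $L$ is the number of FSQ levels (21), "magnitude $1/L$" means noise drawn uniformly from $[-1/L, 1/L]$, and the mode is sampled independently per call. `fsq_round` and `fsq_train_step` are hypothetical names.

```python
import random

FSQ_LEVELS = 21


def fsq_round(x, levels=FSQ_LEVELS):
    """Snap a value to the nearest of `levels` uniformly spaced points in [-1, 1]."""
    half = (levels - 1) / 2        # 10 steps on each side of zero for 21 levels
    x = max(-1.0, min(1.0, x))     # assumption: the latent is bounded to [-1, 1]
    return round(x * half) / half


def fsq_train_step(x, rng, levels=FSQ_LEVELS):
    """Apply the 50% / 25% / 25% training-time schedule to one scalar."""
    u = rng.random()
    if u < 0.5:
        return fsq_round(x, levels), "quantized"
    if u < 0.75:
        noise = rng.uniform(-1.0 / levels, 1.0 / levels)
        return x + noise, "dithered"
    return x, "passthrough"


rng = random.Random(0)
counts = {"quantized": 0, "dithered": 0, "passthrough": 0}
for _ in range(10_000):
    _, mode = fsq_train_step(0.33, rng)
    counts[mode] += 1
print(counts)  # roughly 5000 / 2500 / 2500
```

The semantic VQ branch follows the same pattern with only two modes: a 50% chance of nearest-codebook quantization and a 50% chance of passing the latent through unquantized.
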
This codebase now covers the main codec pieces described in the notes, but a few paper details remain approximate rather than exact.
- Whisper distillation is implemented with decoder hidden states and cross-attention-derived soft alignment.
- The exact offline DTW-based head selection procedure described in the paper is not fully reproduced as a separate preprocessing stage.
- The default training code is practical for experimentation; it does not claim to reproduce the paper's final-scale training.
The following commands were run successfully in this workspace:
/usr/bin/python3 -m pytest -q
Result:
33 passed in 8.49s
Original paper:
- Mistral AI, Voxtral TTS research paper: https://mistral.ai/static/research/voxtral-tts.pdf