Skip to content

Commit 5761866

Browse files
authored
[llm] Beef up wav loader to read audio format 3 (float format) (pytorch#15452)
This PR adds support for audio format 3 (IEEE float) to `wav_loader.h`, so that 32-bit float samples are read directly without normalization. It also adds unit tests for the float path and for rejecting unsupported formats.
1 parent efc2be7 commit 5761866

File tree

2 files changed

+141
-11
lines changed

2 files changed

+141
-11
lines changed

extension/llm/runner/test/test_wav_loader.cpp

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@
1919

2020
using executorch::extension::llm::kOneOverIntMax;
2121
using executorch::extension::llm::kOneOverShortMax;
22+
using executorch::extension::llm::kWavFormatIeeeFloat;
2223
using executorch::extension::llm::load_wav_audio_data;
2324
using executorch::extension::llm::load_wav_header;
2425
using executorch::extension::llm::WavHeader;
2526
using executorch::extension::testing::TempFile;
2627

2728
namespace {
2829

30+
// WAV file format constants
31+
constexpr uint32_t kWavHeaderSizeBeforeData = 36;
32+
constexpr uint32_t kWavHeaderSizeWithData = 44;
33+
2934
// Test fixture to ensure PAL initialization
3035
class WavLoaderTest : public ::testing::Test {
3136
protected:
@@ -51,20 +56,27 @@ void append_le32(std::vector<uint8_t>& out, uint32_t value) {
5156
out.push_back(static_cast<uint8_t>((value >> 24) & 0xFF));
5257
}
5358

59+
// Appends the raw bytes of `value` to `out`, in the order they are laid out
// in host memory (no byte swapping is performed).
void append_float(std::vector<uint8_t>& out, float value) {
  const uint8_t* const first = reinterpret_cast<const uint8_t*>(&value);
  const uint8_t* const last = first + sizeof(float);
  out.insert(out.end(), first, last);
}
65+
5466
std::vector<uint8_t> make_pcm_wav_bytes(
5567
int bits_per_sample,
5668
const std::vector<int32_t>& samples,
5769
uint16_t num_channels = 1,
5870
uint32_t sample_rate = 16000) {
59-
const size_t bytes_per_sample = static_cast<size_t>(bits_per_sample / 8);
60-
const uint32_t subchunk2_size =
71+
const auto bytes_per_sample = static_cast<size_t>(bits_per_sample / 8);
72+
const auto subchunk2_size =
6173
static_cast<uint32_t>(samples.size() * bytes_per_sample);
6274
const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample;
6375
const uint16_t block_align = num_channels * bytes_per_sample;
64-
const uint32_t chunk_size = 36 + subchunk2_size;
76+
const auto chunk_size = kWavHeaderSizeBeforeData + subchunk2_size;
6577

6678
std::vector<uint8_t> bytes;
67-
bytes.reserve(44 + subchunk2_size);
79+
bytes.reserve(kWavHeaderSizeWithData + subchunk2_size);
6880

6981
append_bytes(bytes, "RIFF");
7082
append_le32(bytes, chunk_size);
@@ -91,6 +103,75 @@ std::vector<uint8_t> make_pcm_wav_bytes(
91103
return bytes;
92104
}
93105

106+
std::vector<uint8_t> make_float_wav_bytes(
107+
const std::vector<float>& samples,
108+
uint16_t num_channels = 1,
109+
uint32_t sample_rate = 16000) {
110+
const auto bytes_per_sample = sizeof(float);
111+
const auto subchunk2_size =
112+
static_cast<uint32_t>(samples.size() * bytes_per_sample);
113+
const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample;
114+
const uint16_t block_align = num_channels * bytes_per_sample;
115+
const auto chunk_size = kWavHeaderSizeBeforeData + subchunk2_size;
116+
117+
std::vector<uint8_t> bytes;
118+
bytes.reserve(kWavHeaderSizeWithData + subchunk2_size);
119+
120+
append_bytes(bytes, "RIFF");
121+
append_le32(bytes, chunk_size);
122+
append_bytes(bytes, "WAVE");
123+
append_bytes(bytes, "fmt ");
124+
append_le32(bytes, 16);
125+
append_le16(bytes, 3); // AudioFormat IEEE Float
126+
append_le16(bytes, num_channels);
127+
append_le32(bytes, sample_rate);
128+
append_le32(bytes, byte_rate);
129+
append_le16(bytes, block_align);
130+
append_le16(bytes, 32); // bits per sample
131+
append_bytes(bytes, "data");
132+
append_le32(bytes, subchunk2_size);
133+
134+
for (float sample : samples) {
135+
append_float(bytes, sample);
136+
}
137+
138+
return bytes;
139+
}
140+
141+
std::vector<uint8_t> make_wav_bytes_with_format(
142+
uint16_t audio_format,
143+
int bits_per_sample,
144+
const std::vector<uint8_t>& sample_data,
145+
uint16_t num_channels = 1,
146+
uint32_t sample_rate = 16000) {
147+
const auto bytes_per_sample = static_cast<size_t>(bits_per_sample / 8);
148+
const auto subchunk2_size = static_cast<uint32_t>(sample_data.size());
149+
const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample;
150+
const uint16_t block_align = num_channels * bytes_per_sample;
151+
const auto chunk_size = kWavHeaderSizeBeforeData + subchunk2_size;
152+
153+
std::vector<uint8_t> bytes;
154+
bytes.reserve(kWavHeaderSizeWithData + subchunk2_size);
155+
156+
append_bytes(bytes, "RIFF");
157+
append_le32(bytes, chunk_size);
158+
append_bytes(bytes, "WAVE");
159+
append_bytes(bytes, "fmt ");
160+
append_le32(bytes, 16);
161+
append_le16(bytes, audio_format);
162+
append_le16(bytes, num_channels);
163+
append_le32(bytes, sample_rate);
164+
append_le32(bytes, byte_rate);
165+
append_le16(bytes, block_align);
166+
append_le16(bytes, static_cast<uint16_t>(bits_per_sample));
167+
append_bytes(bytes, "data");
168+
append_le32(bytes, subchunk2_size);
169+
170+
bytes.insert(bytes.end(), sample_data.begin(), sample_data.end());
171+
172+
return bytes;
173+
}
174+
94175
} // namespace
95176

96177
TEST_F(WavLoaderTest, LoadHeaderParsesPcmMetadata) {
@@ -153,3 +234,31 @@ TEST_F(WavLoaderTest, LoadHeaderReturnsNullWhenMagicMissing) {
153234
std::unique_ptr<WavHeader> header = load_wav_header(file.path());
154235
EXPECT_EQ(header, nullptr);
155236
}
237+
238+
// Writes a small IEEE-float WAV file, then checks that the header reports
// the float format at 32 bits and that the loaded samples equal the values
// that were written.
TEST_F(WavLoaderTest, LoadAudioDataFloatFormatReadsDirectly) {
  const std::vector<float> expected = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f};
  const std::vector<uint8_t> encoded = make_float_wav_bytes(expected);
  TempFile wav_file(encoded.data(), encoded.size());

  // Header must identify the file as 32-bit IEEE float.
  const std::unique_ptr<WavHeader> parsed = load_wav_header(wav_file.path());
  ASSERT_NE(parsed, nullptr);
  EXPECT_EQ(parsed->AudioFormat, kWavFormatIeeeFloat);
  EXPECT_EQ(parsed->bitsPerSample, 32);

  // Samples must come back one-for-one.
  const std::vector<float> decoded = load_wav_audio_data(wav_file.path());
  ASSERT_EQ(decoded.size(), expected.size());

  size_t index = 0;
  for (const float want : expected) {
    EXPECT_FLOAT_EQ(decoded[index], want);
    ++index;
  }
}
255+
256+
// Builds a WAV file whose format tag (0x0006) is neither PCM nor IEEE float
// and expects load_wav_audio_data() to terminate with the
// "Unsupported audio format" message.
TEST_F(WavLoaderTest, LoadAudioDataRejectsUnsupportedFormat) {
  const std::vector<uint8_t> payload(4, 0);
  const std::vector<uint8_t> encoded =
      make_wav_bytes_with_format(0x0006, 16, payload);
  TempFile wav_file(encoded.data(), encoded.size());

  EXPECT_DEATH(
      { load_wav_audio_data(wav_file.path()); }, "Unsupported audio format");
}

extension/llm/runner/wav_loader.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
#include <executorch/runtime/platform/log.h>
2323

2424
namespace executorch::extension::llm {
25+
// See https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
26+
constexpr uint16_t kWavFormatPcm = 0x0001;
27+
constexpr uint16_t kWavFormatIeeeFloat = 0x0003;
2528

2629
constexpr float kOneOverIntMax = 1 / static_cast<float>(INT32_MAX);
2730
constexpr float kOneOverShortMax = 1 / static_cast<float>(INT16_MAX);
@@ -168,24 +171,42 @@ inline std::vector<float> load_wav_audio_data(const std::string& fp) {
168171
size_t data_offset = header->dataOffset;
169172
size_t data_size = header->Subchunk2Size;
170173
int bits_per_sample = header->bitsPerSample;
174+
int audio_format = header->AudioFormat;
175+
176+
if (audio_format != kWavFormatPcm && audio_format != kWavFormatIeeeFloat) {
177+
ET_CHECK_MSG(
178+
false,
179+
"Unsupported audio format: 0x%04X. Only PCM (0x%04X) and IEEE Float (0x%04X) are supported.",
180+
audio_format,
181+
kWavFormatPcm,
182+
kWavFormatIeeeFloat);
183+
}
171184

172185
std::vector<float> audio_data;
173186

174187
if (bits_per_sample == 32) {
175188
size_t num_samples = data_size / 4;
176-
audio_data.resize(num_samples);
177-
const int32_t* input_buffer =
178-
reinterpret_cast<const int32_t*>(data + data_offset);
179189

180-
for (size_t i = 0; i < num_samples; ++i) {
181-
audio_data[i] = static_cast<float>(
182-
static_cast<double>(input_buffer[i]) * kOneOverIntMax);
190+
if (audio_format == kWavFormatIeeeFloat) {
191+
// IEEE float format - read directly as floats
192+
const float* input_buffer =
193+
reinterpret_cast<const float*>(data + data_offset);
194+
audio_data.assign(input_buffer, input_buffer + num_samples);
195+
} else {
196+
// PCM integer format - normalize from int32
197+
const int32_t* input_buffer =
198+
reinterpret_cast<const int32_t*>(data + data_offset);
199+
audio_data.resize(num_samples);
200+
for (size_t i = 0; i < num_samples; ++i) {
201+
audio_data[i] = static_cast<float>(
202+
static_cast<double>(input_buffer[i]) * kOneOverIntMax);
203+
}
183204
}
184205
} else if (bits_per_sample == 16) {
185206
size_t num_samples = data_size / 2;
186-
audio_data.resize(num_samples);
187207
const int16_t* input_buffer =
188208
reinterpret_cast<const int16_t*>(data + data_offset);
209+
audio_data.resize(num_samples);
189210

190211
for (size_t i = 0; i < num_samples; ++i) {
191212
audio_data[i] = static_cast<float>(

0 commit comments

Comments
 (0)