
Commit 2d1ac22

mkopcins, chmjkb, and IgorSwat authored

feat: Remove stft calculation within the encoder (#658)
## Description

The Whisper model export now takes a plain waveform instead of a pre-computed STFT, so this PR changes the current API to accept waveforms as well. Before merging this, make sure to re-export all the existing Whisper models with the new export script.

### Introduces a breaking change?

- [ ] Yes
- [x] No

### Type of change

- [ ] Bug fix (change which fixes an issue)
- [x] New feature (change which adds functionality)
- [ ] Documentation update (improves or adds clarity to existing documentation)
- [ ] Other (chores, tests, code style improvements etc.)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

<!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. -->

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

<!-- Link related issues here using #issue-number -->

### Checklist

- [ ] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [ ] My changes generate no new warnings

### Additional notes

<!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->

---------

Co-authored-by: chmjkb <[email protected]>
Co-authored-by: Jakub Chmura <[email protected]>
Co-authored-by: IgorSwat <[email protected]>
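To make the API change concrete, a rough sketch of a native call site after this change is shown below. Only the `ASR::transcribe(std::span<float> waveform, const DecodingOptions &options)` signature comes from this PR's diff; the namespace alias, the already-constructed `asr` instance, and the default-constructed `DecodingOptions` are assumptions for illustration.

```cpp
#include <span>
#include <vector>

// Assumed namespace layout, mirroring rnexecutorch::models::speech_to_text::asr
namespace stt = rnexecutorch::models::speech_to_text;

// Hypothetical helper: transcribe a raw 16 kHz mono recording. The caller no
// longer computes an STFT; the exported Whisper encoder now does that itself.
std::vector<stt::types::Segment>
transcribeRecording(const stt::asr::ASR &asr, std::vector<float> &samples) {
  stt::types::DecodingOptions options{}; // assumed default-constructible
  return asr.transcribe(std::span<float>(samples), options);
}
```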
1 parent 73e7aea commit 2d1ac22

7 files changed (+61 −103 lines changed)
packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp

Lines changed: 0 additions & 46 deletions

@@ -1,6 +1,4 @@
-#include <algorithm>
 #include <cstddef>
-#include <limits>
 #include <math.h>
 #include <rnexecutorch/data_processing/FFT.h>
 #include <rnexecutorch/data_processing/dsp.h>
@@ -18,48 +16,4 @@ std::vector<float> hannWindow(size_t size) {
   return window;
 }
 
-std::vector<float> stftFromWaveform(std::span<const float> waveform,
-                                    size_t fftWindowSize, size_t hopSize) {
-  // Initialize FFT
-  FFT fft(fftWindowSize);
-
-  const auto numFrames = 1 + (waveform.size() - fftWindowSize) / hopSize;
-  const auto numBins = fftWindowSize / 2;
-  const auto hann = hannWindow(fftWindowSize);
-  auto inBuffer = std::vector<float>(fftWindowSize);
-  auto outBuffer = std::vector<std::complex<float>>(fftWindowSize);
-
-  // Output magnitudes in dB
-  std::vector<float> magnitudes;
-  magnitudes.reserve(numFrames * numBins);
-  const auto magnitudeScale = 1.0f / static_cast<float>(fftWindowSize);
-  constexpr auto epsilon = std::numeric_limits<float>::epsilon();
-  constexpr auto dbConversionFactor = 20.0f;
-
-  for (size_t t = 0; t < numFrames; ++t) {
-    const size_t offset = t * hopSize;
-    // Clear the input buffer first
-    std::ranges::fill(inBuffer, 0.0f);
-
-    // Fill frame with windowed signal
-    const size_t samplesToRead =
-        std::min(fftWindowSize, waveform.size() - offset);
-    for (size_t i = 0; i < samplesToRead; i++) {
-      inBuffer[i] = waveform[offset + i] * hann[i];
-    }
-
-    fft.doFFT(inBuffer.data(), outBuffer);
-
-    // Calculate magnitudes in dB (only positive frequencies)
-    for (size_t i = 0; i < numBins; i++) {
-      const auto magnitude = std::abs(outBuffer[i]) * magnitudeScale;
-      const auto magnitude_db =
-          dbConversionFactor * log10f(magnitude + epsilon);
-      magnitudes.push_back(magnitude_db);
-    }
-  }
-
-  return magnitudes;
-}
-
 } // namespace rnexecutorch::dsp

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 6 additions & 6 deletions

@@ -30,7 +30,7 @@ BaseModel::BaseModel(const std::string &modelSource,
 }
 
 std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
-                                              int32_t index) {
+                                              int32_t index) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get input shape");
   }
@@ -56,7 +56,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
 }
 
 std::vector<std::vector<int32_t>>
-BaseModel::getAllInputShapes(std::string methodName) {
+BaseModel::getAllInputShapes(std::string methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get all input shapes");
   }
@@ -88,7 +88,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
 /// to JS. It is not meant to be used within C++. If you want to call forward
 /// from C++ on a BaseModel, please use BaseModel::forward.
 std::vector<JSTensorViewOut>
-BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
+BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -136,7 +136,7 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
 }
 
 Result<executorch::runtime::MethodMeta>
-BaseModel::getMethodMeta(const std::string &methodName) {
+BaseModel::getMethodMeta(const std::string &methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get method meta!");
   }
@@ -161,7 +161,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {
 
 Result<std::vector<EValue>>
 BaseModel::execute(const std::string &methodName,
-                   const std::vector<EValue> &input_value) {
+                   const std::vector<EValue> &input_value) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded, cannot run execute.");
   }
@@ -175,7 +175,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
 void BaseModel::unload() noexcept { module_.reset(nullptr); }
 
 std::vector<int32_t>
-BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
+BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
   auto sizes = tensor.sizes();
   return std::vector<int32_t>(sizes.begin(), sizes.end());
 }

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 10 additions & 7 deletions

@@ -25,18 +25,20 @@ class BaseModel {
       Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors);
   std::size_t getMemoryLowerBound() const noexcept;
   void unload() noexcept;
-  std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
+  std::vector<int32_t> getInputShape(std::string method_name,
+                                     int32_t index) const;
   std::vector<std::vector<int32_t>>
-  getAllInputShapes(std::string methodName = "forward");
+  getAllInputShapes(std::string methodName = "forward") const;
   std::vector<JSTensorViewOut>
-  forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
+  forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const;
   Result<std::vector<EValue>> forward(const EValue &input_value) const;
   Result<std::vector<EValue>>
  forward(const std::vector<EValue> &input_value) const;
-  Result<std::vector<EValue>> execute(const std::string &methodName,
-                                      const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>>
+  execute(const std::string &methodName,
+          const std::vector<EValue> &input_value) const;
   Result<executorch::runtime::MethodMeta>
-  getMethodMeta(const std::string &methodName);
+  getMethodMeta(const std::string &methodName) const;
 
 protected:
   // If possible, models should not use the JS runtime to keep JSI internals
@@ -49,7 +51,8 @@ class BaseModel {
   std::size_t memorySizeLowerBound{0};
 
 private:
-  std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
+  std::vector<int32_t>
+  getTensorShape(const executorch::aten::Tensor &tensor) const;
 };
 } // namespace models

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp

Lines changed: 12 additions & 19 deletions

@@ -4,7 +4,6 @@
 #include "ASR.h"
 #include "executorch/extension/tensor/tensor_ptr.h"
 #include "rnexecutorch/data_processing/Numerical.h"
-#include "rnexecutorch/data_processing/dsp.h"
 #include "rnexecutorch/data_processing/gzip.h"
 
 namespace rnexecutorch::models::speech_to_text::asr {
@@ -37,8 +36,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
   return seq;
 }
 
-GenerationResult ASR::generate(std::span<const float> waveform,
-                               float temperature,
+GenerationResult ASR::generate(std::span<float> waveform, float temperature,
                                const DecodingOptions &options) const {
   std::vector<float> encoderOutput = this->encode(waveform);
 
@@ -94,7 +92,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
 }
 
 std::vector<Segment>
-ASR::generateWithFallback(std::span<const float> waveform,
+ASR::generateWithFallback(std::span<float> waveform,
                           const DecodingOptions &options) const {
   std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
   std::vector<int32_t> bestTokens;
@@ -209,7 +207,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
   return wordObjs;
 }
 
-std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
+std::vector<Segment> ASR::transcribe(std::span<float> waveform,
                                      const DecodingOptions &options) const {
   int32_t seek = 0;
   std::vector<Segment> results;
@@ -218,7 +216,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
     int32_t start = seek * ASR::kSamplingRate;
     const auto end = std::min<int32_t>(
         (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
-    std::span<const float> chunk = waveform.subspan(start, end - start);
+    auto chunk = waveform.subspan(start, end - start);
 
     if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
       break;
@@ -246,19 +244,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
   return results;
 }
 
-std::vector<float> ASR::encode(std::span<const float> waveform) const {
-  constexpr int32_t fftWindowSize = 512;
-  constexpr int32_t stftHopLength = 160;
-  constexpr int32_t innerDim = 256;
-
-  std::vector<float> preprocessedData =
-      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
-  const auto numFrames =
-      static_cast<int32_t>(preprocessedData.size()) / innerDim;
-  std::vector<int32_t> inputShape = {numFrames, innerDim};
+std::vector<float> ASR::encode(std::span<float> waveform) const {
+  auto inputShape = {static_cast<int32_t>(waveform.size())};
 
   const auto modelInputTensor = executorch::extension::make_tensor_ptr(
-      std::move(inputShape), std::move(preprocessedData));
+      std::move(inputShape), waveform.data(),
+      executorch::runtime::etensor::ScalarType::Float);
   const auto encoderResult = this->encoder->forward(modelInputTensor);
 
   if (!encoderResult.ok()) {
@@ -268,7 +259,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
   }
 
   const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
-  const int32_t outputNumel = decoderOutputTensor.numel();
+  const auto outputNumel = decoderOutputTensor.numel();
 
   const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
   return {dataPtr, dataPtr + outputNumel};
@@ -277,8 +268,10 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
 std::vector<float> ASR::decode(std::span<int32_t> tokens,
                                std::span<float> encoderOutput) const {
   std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
+
   auto tokenTensor = executorch::extension::make_tensor_ptr(
-      std::move(tokenShape), tokens.data(), ScalarType::Int);
+      tokenShape, tokensLong.data(), ScalarType::Long);
 
   const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
   std::vector<int32_t> encShape = {1, ASR::kNumFrames,
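The practical effect of the `ASR::encode` change above is on the encoder's expected input: previously a `{numFrames, 256}` matrix of STFT magnitudes built with a 512-sample window and a hop of 160 (the constants removed above), now just the raw waveform as a 1-D tensor of shape `{numSamples}`. A small self-contained illustration of the shape change, using 30 s of 16 kHz audio as an assumed example input:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Example value: 30 s of 16 kHz audio (assumed), plus the STFT constants
  // that the old ASR::encode used before this change.
  constexpr std::size_t numSamples = 30 * 16000;
  constexpr std::size_t fftWindowSize = 512;
  constexpr std::size_t stftHopLength = 160;
  constexpr std::size_t innerDim = 256;

  // Old encoder input: STFT magnitudes of shape {numFrames, innerDim}.
  const std::size_t numFrames =
      1 + (numSamples - fftWindowSize) / stftHopLength;
  std::printf("old encoder input shape: {%zu, %zu}\n", numFrames, innerDim);

  // New encoder input: the waveform itself, of shape {numSamples}.
  std::printf("new encoder input shape: {%zu}\n", numSamples);
  return 0;
}
```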

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h

Lines changed: 4 additions & 5 deletions

@@ -14,9 +14,9 @@ class ASR {
       const models::BaseModel *decoder,
       const TokenizerModule *tokenizer);
   std::vector<types::Segment>
-  transcribe(std::span<const float> waveform,
+  transcribe(std::span<float> waveform,
              const types::DecodingOptions &options) const;
-  std::vector<float> encode(std::span<const float> waveform) const;
+  std::vector<float> encode(std::span<float> waveform) const;
   std::vector<float> decode(std::span<int32_t> tokens,
                             std::span<float> encoderOutput) const;
 
@@ -44,11 +44,10 @@ class ASR {
 
   std::vector<int32_t>
   getInitialSequence(const types::DecodingOptions &options) const;
-  types::GenerationResult generate(std::span<const float> waveform,
-                                   float temperature,
+  types::GenerationResult generate(std::span<float> waveform, float temperature,
                                    const types::DecodingOptions &options) const;
   std::vector<types::Segment>
-  generateWithFallback(std::span<const float> waveform,
+  generateWithFallback(std::span<float> waveform,
                        const types::DecodingOptions &options) const;
   std::vector<types::Segment>
   calculateWordLevelTimestamps(std::span<const int32_t> tokens,

packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp

Lines changed: 1 addition & 2 deletions

@@ -6,7 +6,6 @@
 #include <array>
 #include <functional>
 #include <numeric>
-#include <ranges>
 #include <vector>
 
 namespace rnexecutorch::models::voice_activity_detection {
@@ -158,4 +157,4 @@ VoiceActivityDetection::postprocess(const std::vector<float> &scores,
   return speechSegments;
 }
 
-} // namespace rnexecutorch::models::voice_activity_detection
+} // namespace rnexecutorch::models::voice_activity_detection

packages/react-native-executorch/src/constants/modelUrls.ts

Lines changed: 28 additions & 18 deletions

@@ -307,29 +307,32 @@ export const STYLE_TRANSFER_UDNIE = {
 };
 
 // S2T
-const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`;
-const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
-const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;
+const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/tokenizer.json`;
+const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
+const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;
 
-const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`;
-const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`;
-const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`;
+const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`;
+const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`;
 
-const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`;
-const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`;
-const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`;
+const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${NEXT_VERSION_TAG}/tokenizer.json`;
+const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${NEXT_VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`;
+const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${NEXT_VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`;
 
-const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`;
-const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`;
-const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`;
+const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${NEXT_VERSION_TAG}/tokenizer.json`;
+const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${NEXT_VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`;
+const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${NEXT_VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`;
 
-const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`;
-const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`;
-const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`;
+const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${NEXT_VERSION_TAG}/tokenizer.json`;
+const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`;
+const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`;
 
-const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`;
-const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`;
-const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`;
+const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${NEXT_VERSION_TAG}/tokenizer.json`;
+const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${NEXT_VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`;
+const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${NEXT_VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`;
+
+const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/tokenizer.json`;
+const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`;
+const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`;
 
 export const WHISPER_TINY_EN = {
   isMultilingual: false,
@@ -338,6 +341,13 @@ export const WHISPER_TINY_EN = {
   tokenizerSource: WHISPER_TINY_EN_TOKENIZER,
 };
 
+export const WHISPER_TINY_EN_QUANTIZED = {
+  isMultilingual: false,
+  encoderSource: WHISPER_TINY_EN_ENCODER_QUANTIZED,
+  decoderSource: WHISPER_TINY_EN_DECODER_QUANTIZED,
+  tokenizerSource: WHISPER_TINY_EN_TOKENIZER,
+};
+
 export const WHISPER_BASE_EN = {
   isMultilingual: false,
   encoderSource: WHISPER_BASE_EN_ENCODER,
