Skip to content

Commit efda4fd

Browse files
committed
feat: remove stft calculation within the encoder
1 parent 83130d3 commit efda4fd

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
5656
}
5757

5858
std::vector<std::vector<int32_t>>
59-
BaseModel::getAllInputShapes(std::string methodName) {
59+
BaseModel::getAllInputShapes(std::string methodName) const {
6060
if (!module_) {
6161
throw std::runtime_error("Model not loaded: Cannot get all input shapes");
6262
}

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ class BaseModel {
2727
void unload() noexcept;
2828
std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
2929
std::vector<std::vector<int32_t>>
30-
getAllInputShapes(std::string methodName = "forward");
30+
getAllInputShapes(std::string methodName = "forward") const;
3131
std::vector<JSTensorViewOut>
3232
forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
3333
Result<std::vector<EValue>> forward(const EValue &input_value) const;

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp

Lines changed: 17 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44
#include "ASR.h"
55
#include "executorch/extension/tensor/tensor_ptr.h"
66
#include "rnexecutorch/data_processing/Numerical.h"
7-
#include "rnexecutorch/data_processing/dsp.h"
87
#include "rnexecutorch/data_processing/gzip.h"
8+
#include <rnexecutorch/Log.h>
99

1010
namespace rnexecutorch::models::speech_to_text::asr {
1111

@@ -37,8 +37,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
3737
return seq;
3838
}
3939

40-
GenerationResult ASR::generate(std::span<const float> waveform,
41-
float temperature,
40+
GenerationResult ASR::generate(std::span<float> waveform, float temperature,
4241
const DecodingOptions &options) const {
4342
std::vector<float> encoderOutput = this->encode(waveform);
4443

@@ -94,7 +93,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
9493
}
9594

9695
std::vector<Segment>
97-
ASR::generateWithFallback(std::span<const float> waveform,
96+
ASR::generateWithFallback(std::span<float> waveform,
9897
const DecodingOptions &options) const {
9998
std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
10099
std::vector<int32_t> bestTokens;
@@ -209,7 +208,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
209208
return wordObjs;
210209
}
211210

212-
std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
211+
std::vector<Segment> ASR::transcribe(std::span<float> waveform,
213212
const DecodingOptions &options) const {
214213
int32_t seek = 0;
215214
std::vector<Segment> results;
@@ -218,7 +217,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
218217
int32_t start = seek * ASR::kSamplingRate;
219218
const auto end = std::min<int32_t>(
220219
(seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
221-
std::span<const float> chunk = waveform.subspan(start, end - start);
220+
auto chunk = waveform.subspan(start, end - start);
222221

223222
if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
224223
break;
@@ -246,19 +245,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
246245
return results;
247246
}
248247

249-
std::vector<float> ASR::encode(std::span<const float> waveform) const {
250-
constexpr int32_t fftWindowSize = 512;
251-
constexpr int32_t stftHopLength = 160;
252-
constexpr int32_t innerDim = 256;
253-
254-
std::vector<float> preprocessedData =
255-
dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
256-
const auto numFrames =
257-
static_cast<int32_t>(preprocessedData.size()) / innerDim;
258-
std::vector<int32_t> inputShape = {numFrames, innerDim};
248+
std::vector<float> ASR::encode(std::span<float> waveform) const {
249+
auto inputShape = {static_cast<int32_t>(waveform.size())};
259250

260251
const auto modelInputTensor = executorch::extension::make_tensor_ptr(
261-
std::move(inputShape), std::move(preprocessedData));
252+
std::move(inputShape), waveform.data(),
253+
executorch::runtime::etensor::ScalarType::Float);
262254
const auto encoderResult = this->encoder->forward(modelInputTensor);
263255

264256
if (!encoderResult.ok()) {
@@ -268,7 +260,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
268260
}
269261

270262
const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
271-
const int32_t outputNumel = decoderOutputTensor.numel();
263+
const auto outputNumel = decoderOutputTensor.numel();
272264

273265
const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
274266
return {dataPtr, dataPtr + outputNumel};
@@ -277,12 +269,18 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
277269
std::vector<float> ASR::decode(std::span<int32_t> tokens,
278270
std::span<float> encoderOutput) const {
279271
std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
272+
auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
273+
280274
auto tokenTensor = executorch::extension::make_tensor_ptr(
281-
std::move(tokenShape), tokens.data(), ScalarType::Int);
275+
tokenShape, tokensLong.data(), ScalarType::Long);
282276

283277
const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
284278
std::vector<int32_t> encShape = {1, ASR::kNumFrames,
285279
encoderOutputSize / ASR::kNumFrames};
280+
log(LOG_LEVEL::Debug, encShape);
281+
log(LOG_LEVEL::Debug, tokenShape);
282+
log(LOG_LEVEL::Debug, this->encoder->getAllInputShapes());
283+
log(LOG_LEVEL::Debug, this->decoder->getAllInputShapes());
286284
auto encoderTensor = executorch::extension::make_tensor_ptr(
287285
std::move(encShape), encoderOutput.data(), ScalarType::Float);
288286

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,9 @@ class ASR {
1414
const models::BaseModel *decoder,
1515
const TokenizerModule *tokenizer);
1616
std::vector<types::Segment>
17-
transcribe(std::span<const float> waveform,
17+
transcribe(std::span<float> waveform,
1818
const types::DecodingOptions &options) const;
19-
std::vector<float> encode(std::span<const float> waveform) const;
19+
std::vector<float> encode(std::span<float> waveform) const;
2020
std::vector<float> decode(std::span<int32_t> tokens,
2121
std::span<float> encoderOutput) const;
2222

@@ -44,11 +44,10 @@ class ASR {
4444

4545
std::vector<int32_t>
4646
getInitialSequence(const types::DecodingOptions &options) const;
47-
types::GenerationResult generate(std::span<const float> waveform,
48-
float temperature,
47+
types::GenerationResult generate(std::span<float> waveform, float temperature,
4948
const types::DecodingOptions &options) const;
5049
std::vector<types::Segment>
51-
generateWithFallback(std::span<const float> waveform,
50+
generateWithFallback(std::span<float> waveform,
5251
const types::DecodingOptions &options) const;
5352
std::vector<types::Segment>
5453
calculateWordLevelTimestamps(std::span<const int32_t> tokens,

0 commit comments

Comments
 (0)