44#include " ASR.h"
55#include " executorch/extension/tensor/tensor_ptr.h"
66#include " rnexecutorch/data_processing/Numerical.h"
7- #include " rnexecutorch/data_processing/dsp.h"
87#include " rnexecutorch/data_processing/gzip.h"
8+ #include < rnexecutorch/Log.h>
99
1010namespace rnexecutorch ::models::speech_to_text::asr {
1111
@@ -37,8 +37,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
3737 return seq;
3838}
3939
40- GenerationResult ASR::generate (std::span<const float > waveform,
41- float temperature,
40+ GenerationResult ASR::generate (std::span<float > waveform, float temperature,
4241 const DecodingOptions &options) const {
4342 std::vector<float > encoderOutput = this ->encode (waveform);
4443
@@ -94,7 +93,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
9493}
9594
9695std::vector<Segment>
97- ASR::generateWithFallback (std::span<const float > waveform,
96+ ASR::generateWithFallback (std::span<float > waveform,
9897 const DecodingOptions &options) const {
9998 std::vector<float > temperatures = {0 .0f , 0 .2f , 0 .4f , 0 .6f , 0 .8f , 1 .0f };
10099 std::vector<int32_t > bestTokens;
@@ -209,7 +208,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
209208 return wordObjs;
210209}
211210
212- std::vector<Segment> ASR::transcribe (std::span<const float > waveform,
211+ std::vector<Segment> ASR::transcribe (std::span<float > waveform,
213212 const DecodingOptions &options) const {
214213 int32_t seek = 0 ;
215214 std::vector<Segment> results;
@@ -218,7 +217,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
218217 int32_t start = seek * ASR::kSamplingRate ;
219218 const auto end = std::min<int32_t >(
220219 (seek + ASR::kChunkSize ) * ASR::kSamplingRate , waveform.size ());
221- std::span< const float > chunk = waveform.subspan (start, end - start);
220+ auto chunk = waveform.subspan (start, end - start);
222221
223222 if (std::cmp_less (chunk.size (), ASR::kMinChunkSamples )) {
224223 break ;
@@ -246,19 +245,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
246245 return results;
247246}
248247
249- std::vector<float > ASR::encode (std::span<const float > waveform) const {
250- constexpr int32_t fftWindowSize = 512 ;
251- constexpr int32_t stftHopLength = 160 ;
252- constexpr int32_t innerDim = 256 ;
253-
254- std::vector<float > preprocessedData =
255- dsp::stftFromWaveform (waveform, fftWindowSize, stftHopLength);
256- const auto numFrames =
257- static_cast <int32_t >(preprocessedData.size ()) / innerDim;
258- std::vector<int32_t > inputShape = {numFrames, innerDim};
248+ std::vector<float > ASR::encode (std::span<float > waveform) const {
249+ auto inputShape = {static_cast <int32_t >(waveform.size ())};
259250
260251 const auto modelInputTensor = executorch::extension::make_tensor_ptr (
261- std::move (inputShape), std::move (preprocessedData));
252+ std::move (inputShape), waveform.data (),
253+ executorch::runtime::etensor::ScalarType::Float);
262254 const auto encoderResult = this ->encoder ->forward (modelInputTensor);
263255
264256 if (!encoderResult.ok ()) {
@@ -268,7 +260,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
268260 }
269261
270262 const auto decoderOutputTensor = encoderResult.get ().at (0 ).toTensor ();
271- const int32_t outputNumel = decoderOutputTensor.numel ();
263+ const auto outputNumel = decoderOutputTensor.numel ();
272264
273265 const float *const dataPtr = decoderOutputTensor.const_data_ptr <float >();
274266 return {dataPtr, dataPtr + outputNumel};
@@ -277,12 +269,18 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
277269std::vector<float > ASR::decode (std::span<int32_t > tokens,
278270 std::span<float > encoderOutput) const {
279271 std::vector<int32_t > tokenShape = {1 , static_cast <int32_t >(tokens.size ())};
272+ auto tokensLong = std::vector<int64_t >(tokens.begin (), tokens.end ());
273+
280274 auto tokenTensor = executorch::extension::make_tensor_ptr (
281- std::move ( tokenShape), tokens .data (), ScalarType::Int );
275+ tokenShape, tokensLong .data (), ScalarType::Long );
282276
283277 const auto encoderOutputSize = static_cast <int32_t >(encoderOutput.size ());
284278 std::vector<int32_t > encShape = {1 , ASR::kNumFrames ,
285279 encoderOutputSize / ASR::kNumFrames };
280+ log (LOG_LEVEL::Debug, encShape);
281+ log (LOG_LEVEL::Debug, tokenShape);
282+ log (LOG_LEVEL::Debug, this ->encoder ->getAllInputShapes ());
283+ log (LOG_LEVEL::Debug, this ->decoder ->getAllInputShapes ());
286284 auto encoderTensor = executorch::extension::make_tensor_ptr (
287285 std::move (encShape), encoderOutput.data (), ScalarType::Float);
288286
0 commit comments