diff --git a/src/Nn/LabelScorer/ScoringContext.hh b/src/Nn/LabelScorer/ScoringContext.hh index 425aaacb1..af2adf10f 100644 --- a/src/Nn/LabelScorer/ScoringContext.hh +++ b/src/Nn/LabelScorer/ScoringContext.hh @@ -124,7 +124,7 @@ struct OnnxHiddenState : public Core::ReferenceCounted { } }; -typedef Core::Ref OnnxHiddenStateRef; +typedef Core::Ref OnnxHiddenStateRef; /* * Scoring context consisting of a hidden state. @@ -132,14 +132,15 @@ typedef Core::Ref OnnxHiddenStateRef; * from the same label history. */ struct OnnxHiddenStateScoringContext : public ScoringContext { - std::vector labelSeq; // Used for hashing - OnnxHiddenStateRef hiddenState; + std::vector labelSeq; // Used for hashing + mutable OnnxHiddenStateRef hiddenState; + mutable bool requiresFinalize; OnnxHiddenStateScoringContext() - : labelSeq(), hiddenState() {} + : labelSeq(), hiddenState(), requiresFinalize(false) {} OnnxHiddenStateScoringContext(std::vector const& labelSeq, OnnxHiddenStateRef state) - : labelSeq(labelSeq), hiddenState(state) {} + : labelSeq(labelSeq), hiddenState(state), requiresFinalize(false) {} bool isEqual(ScoringContextRef const& other) const; size_t hash() const; diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index be5bc8179..3e46484b8 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -54,7 +54,7 @@ const Core::ParameterInt StatefulOnnxLabelScorer::paramMaxBatchSize( const Core::ParameterInt StatefulOnnxLabelScorer::paramMaxCachedScores( "max-cached-score-vectors", - "Maximum size of cache that maps histories to scores. This prevents memory overflow in case of very long audio segments.", + "Maximum size of cache that maps scoring contexts to scores. This prevents memory overflow in case of very long audio segments.", 1000); // Scorer only takes hidden states as input which are not part of the IO spec @@ -201,7 +201,7 @@ Core::Ref StatefulOnnxLabelScorer::getInitialScoringContex } Core::Ref StatefulOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { - OnnxHiddenStateScoringContextRef history(dynamic_cast(request.context.get())); + OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); bool updateState = false; switch (request.transitionType) { @@ -224,23 +224,19 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( error() << "Unknown transition type " << request.transitionType; } - // If history is not going to be modified, return the original one + // If scoring context is not going to be modified, return the original one if (not updateState) { return request.context; } - std::vector newLabelSeq(history->labelSeq); + std::vector newLabelSeq(scoringContext->labelSeq); newLabelSeq.push_back(request.nextToken); - OnnxHiddenStateRef newHiddenState; - if (not history->hiddenState) { // Sentinel start-state - newHiddenState = updatedHiddenState(computeInitialHiddenState(), request.nextToken); - } - else { - newHiddenState = updatedHiddenState(history->hiddenState, request.nextToken); - } + // Re-use previous hidden-state but mark that finalization (i.e. hidden-state update) is required + auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), scoringContext->hiddenState)); + newScoringContext->requiresFinalize = true; - return Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), newHiddenState)); + return newScoringContext; } void StatefulOnnxLabelScorer::addInput(DataView const& input) { @@ -264,41 +260,44 @@ std::optional StatefulOnnxLabelScorer::computeScor result.scores.reserve(requests.size()); /* - * Identify unique histories that still need session runs + * Identify unique scoring contexts that still need session runs */ - std::unordered_set uniqueUncachedHistories; + std::unordered_set uniqueUncachedScoringContexts; for (auto& request : requests) { - OnnxHiddenStateScoringContextRef historyPtr(dynamic_cast(request.context.get())); - if (not scoreCache_.contains(historyPtr)) { - // Group by unique history - uniqueUncachedHistories.emplace(historyPtr); + // We need to finalize all scoring contexts before using them for scoring again. + + OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); + finalizeScoringContext(scoringContext); + if (not scoreCache_.contains(scoringContext)) { + // Group by unique scoring context + uniqueUncachedScoringContexts.emplace(scoringContext); } } - std::vector historyBatch; - historyBatch.reserve(std::min(uniqueUncachedHistories.size(), maxBatchSize_)); - for (auto history : uniqueUncachedHistories) { - historyBatch.push_back(history); - if (historyBatch.size() == maxBatchSize_) { // Batch is full -> forward now - forwardBatch(historyBatch); - historyBatch.clear(); + std::vector scoringContextBatch; + scoringContextBatch.reserve(std::min(uniqueUncachedScoringContexts.size(), maxBatchSize_)); + for (auto scoringContext : uniqueUncachedScoringContexts) { + scoringContextBatch.push_back(scoringContext); + if (scoringContextBatch.size() == maxBatchSize_) { // Batch is full -> forward now + forwardBatch(scoringContextBatch); + scoringContextBatch.clear(); } } - forwardBatch(historyBatch); // Forward remaining histories + forwardBatch(scoringContextBatch); // Forward remaining scoring contexts /* * Assign from cache map to result vector */ for (const auto& request : requests) { - OnnxHiddenStateScoringContextRef history(dynamic_cast(request.context.get())); + OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); - verify(scoreCache_.contains(history)); - auto const& scores = scoreCache_.get(history)->get(); + verify(scoreCache_.contains(scoringContext)); + auto const& scores = scoreCache_.get(scoringContext)->get(); result.scores.push_back(scores.at(request.nextToken)); - result.timeframes.push_back(history->labelSeq.size()); + result.timeframes.push_back(scoringContext->labelSeq.size()); } return result; @@ -425,8 +424,25 @@ OnnxHiddenStateRef StatefulOnnxLabelScorer::updatedHiddenState(OnnxHiddenStateRe return newHiddenState; } -void StatefulOnnxLabelScorer::forwardBatch(std::vector const& historyBatch) { - if (historyBatch.empty()) { +void StatefulOnnxLabelScorer::finalizeScoringContext(OnnxHiddenStateScoringContextRef const& scoringContext) { + // If this scoring context does not need finalization, don't change it + if (not scoringContext->requiresFinalize) { + return; + } + + auto hiddenState = scoringContext->hiddenState; + + if (not hiddenState) { // Sentinel start-state + hiddenState = computeInitialHiddenState(); + } + verify(not scoringContext->labelSeq.empty()); + + scoringContext->hiddenState = updatedHiddenState(hiddenState, scoringContext->labelSeq.back()); + scoringContext->requiresFinalize = false; +} + +void StatefulOnnxLabelScorer::forwardBatch(std::vector const& scoringContextBatch) { + if (scoringContextBatch.empty()) { return; } @@ -439,11 +455,11 @@ void StatefulOnnxLabelScorer::forwardBatch(std::vector stateValues; - stateValues.reserve(historyBatch.size()); + stateValues.reserve(scoringContextBatch.size()); - for (size_t b = 0ul; b < historyBatch.size(); ++b) { - auto history = historyBatch[b]; - auto hiddenState = history->hiddenState; + for (size_t b = 0ul; b < scoringContextBatch.size(); ++b) { + auto scoringContext = scoringContextBatch[b]; + auto hiddenState = scoringContext->hiddenState; if (not hiddenState) { // Sentinel hidden-state hiddenState = computeInitialHiddenState(); } @@ -461,10 +477,10 @@ void StatefulOnnxLabelScorer::forwardBatch(std::vector scoreVec; sessionOutputs.front().get(b, scoreVec); - scoreCache_.put(historyBatch[b], std::move(scoreVec)); + scoreCache_.put(scoringContextBatch[b], std::move(scoreVec)); } } diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh index 29f1777c8..0f343ac51 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh @@ -71,10 +71,10 @@ public: void reset() override; - // If startLabelIndex is set, forward that through the state updater to obtain the start history + // If startLabelIndex is set, forward that through the state updater to obtain the start ScoringContext Core::Ref getInitialScoringContext() override; - // Forward hidden-state through state-updater ONNX model + // Append the new token to the label sequence; does not update the hidden-state. This is only done once the scoringContext is used for scoring again. Core::Ref extendedScoringContext(LabelScorer::Request const& request) override; // Add a single encoder outputs to buffer @@ -87,13 +87,17 @@ protected: size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; private: - // Forward a batch of histories through the ONNX model and put the resulting scores into the score cache - void forwardBatch(std::vector const& historyBatch); + // Forward a batch of scoringContexts through the ONNX model and put the resulting scores into the score cache + void forwardBatch(std::vector const& scoringContextBatch); + // Computes new hidden state based on previous hidden state and next token through state-updater call OnnxHiddenStateRef updatedHiddenState(OnnxHiddenStateRef const& hiddenState, LabelIndex nextToken); + // Replace hidden-state in scoringContext with an updated version that includes the last label + void finalizeScoringContext(OnnxHiddenStateScoringContextRef const& scoringContext); + // Since the hidden-state matrix depends on the encoder time axis, we cannot create properly create hidden-states until all encoder states have been passed. - // So getStartHistory sets the initial hidden-state to a sentinel value (empty Ref) and when other functions such as `extendedHistory` and `getScoresWithTime` + // So getInitialScoringContext sets the initial hidden-state to a sentinel value (empty Ref) and when other functions such as `extendedScoringContext` and `getScoresWithTime` // encounter this sentinel value they call `computeInitialHiddenState` instead to get a usable hidden-state. OnnxHiddenStateRef computeInitialHiddenState();