From 540c31144526003dad24c0dc88aac52eda6b0487 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 23 Sep 2025 17:33:58 +0200 Subject: [PATCH 1/3] Introduce finalizeScoringContext function for updating hidden states --- src/Nn/LabelScorer/CombineLabelScorer.cc | 15 +++++++++++ src/Nn/LabelScorer/CombineLabelScorer.hh | 1 + .../LabelScorer/EncoderDecoderLabelScorer.cc | 4 +++ .../LabelScorer/EncoderDecoderLabelScorer.hh | 1 + src/Nn/LabelScorer/LabelScorer.cc | 4 +++ src/Nn/LabelScorer/LabelScorer.hh | 8 +++++- src/Nn/LabelScorer/ScoringContext.hh | 6 ++++- src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc | 27 ++++++++++++++++--- src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh | 4 ++- .../LexiconfreeTimesyncBeamSearch.cc | 3 +++ .../TreeTimesyncBeamSearch.cc | 3 +++ 11 files changed, 70 insertions(+), 6 deletions(-) diff --git a/src/Nn/LabelScorer/CombineLabelScorer.cc b/src/Nn/LabelScorer/CombineLabelScorer.cc index 706c6c5cf..9ebeb28b3 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.cc +++ b/src/Nn/LabelScorer/CombineLabelScorer.cc @@ -71,6 +71,21 @@ ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& requ return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); } +ScoringContextRef CombineLabelScorer::finalizeScoringContext(ScoringContextRef const& context) { + auto combineContext = dynamic_cast(context.get()); + + std::vector extScoringContexts; + extScoringContexts.reserve(scaledScorers_.size()); + + auto scorerIt = scaledScorers_.begin(); + auto contextIt = combineContext->scoringContexts.begin(); + + for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) { + extScoringContexts.push_back(scorerIt->scorer->finalizeScoringContext(*contextIt)); + } + return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); +} + void CombineLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { std::vector combineContexts; combineContexts.reserve(activeContexts.internalSize()); diff --git a/src/Nn/LabelScorer/CombineLabelScorer.hh b/src/Nn/LabelScorer/CombineLabelScorer.hh index a6baf6883..ef0425d64 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.hh +++ b/src/Nn/LabelScorer/CombineLabelScorer.hh @@ -48,6 +48,7 @@ public: // Combine extended ScoringContexts from all sub-scorers ScoringContextRef extendedScoringContext(Request const& request) override; + ScoringContextRef finalizeScoringContext(ScoringContextRef const& context) override; // Cleanup all sub-scorers void cleanupCaches(Core::CollapsedVector const& activeContexts) override; diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc index 257c08e1b..0a79c3a34 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc @@ -37,6 +37,10 @@ ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request cons return decoder_->extendedScoringContext(request); } +ScoringContextRef EncoderDecoderLabelScorer::finalizeScoringContext(ScoringContextRef const& context) { + return decoder_->finalizeScoringContext(context); +} + void EncoderDecoderLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { decoder_->cleanupCaches(activeContexts); } diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh index 204204aa6..08bc65283 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh @@ -48,6 +48,7 @@ public: // Get extended context from decoder component ScoringContextRef extendedScoringContext(Request const& request) override; + ScoringContextRef finalizeScoringContext(ScoringContextRef const& context) override; // Cleanup decoder component. Encoder is "self-cleaning" already in that it only stores outputs until they are // retrieved. diff --git a/src/Nn/LabelScorer/LabelScorer.cc b/src/Nn/LabelScorer/LabelScorer.cc index 010aa56a5..9fd930a3a 100644 --- a/src/Nn/LabelScorer/LabelScorer.cc +++ b/src/Nn/LabelScorer/LabelScorer.cc @@ -33,6 +33,10 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { } } +ScoringContextRef LabelScorer::finalizeScoringContext(ScoringContextRef const& context) { + return context; +} + std::optional LabelScorer::computeScoresWithTimes(std::vector const& requests) { // By default, just loop over the non-batched `computeScoreWithTime` and collect the results ScoresWithTimes result; diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 6666d28d4..d5c21f81f 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -118,9 +118,15 @@ public: // Gets initial scoring context to use for the hypotheses in the first search step virtual ScoringContextRef getInitialScoringContext() = 0; - // Creates a copy of the context in the request that is extended using the given token and transition type + // Creates a copy of the context in the request that is extended such that the hashing + // and equality operators return the correct result but may omit expensive operations + // that do not affect the hash (e.g. hidden-state updates). virtual ScoringContextRef extendedScoringContext(Request const& request) = 0; + // Finalize the scoring context by applying remaining expensive operations (e.g. hidden-state updates) + // that don't affect the hash + virtual ScoringContextRef finalizeScoringContext(ScoringContextRef const& context); + // Given a collection of currently active contexts, this function can clean up values in any internal caches // or buffers that are saved for scoring contexts which no longer are active. virtual void cleanupCaches(Core::CollapsedVector const& activeContexts) {}; diff --git a/src/Nn/LabelScorer/ScoringContext.hh b/src/Nn/LabelScorer/ScoringContext.hh index 425aaacb1..b3f69083a 100644 --- a/src/Nn/LabelScorer/ScoringContext.hh +++ b/src/Nn/LabelScorer/ScoringContext.hh @@ -30,6 +30,8 @@ static constexpr LabelIndex invalidLabelIndex = Core::Type::max; * Empty scoring context base class */ struct ScoringContext : public Core::ReferenceCounted { + bool requiresFinalize = false; + virtual ~ScoringContext() = default; virtual bool isEqual(Core::Ref const& other) const; @@ -139,7 +141,9 @@ struct OnnxHiddenStateScoringContext : public ScoringContext { : labelSeq(), hiddenState() {} OnnxHiddenStateScoringContext(std::vector const& labelSeq, OnnxHiddenStateRef state) - : labelSeq(labelSeq), hiddenState(state) {} + : labelSeq(labelSeq), hiddenState(state) { + requiresFinalize = false; + } bool isEqual(ScoringContextRef const& other) const; size_t hash() const; diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index be5bc8179..3710b12bb 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -232,15 +232,33 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( std::vector newLabelSeq(history->labelSeq); newLabelSeq.push_back(request.nextToken); + auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), history->hiddenState)); + newScoringContext->requiresFinalize = true; + + return newScoringContext; +} + +Core::Ref StatefulOnnxLabelScorer::finalizeScoringContext(ScoringContextRef const& context) { + // If this scoring context does not need finalization, just return it + if (not context->requiresFinalize) { + return context; + } + + OnnxHiddenStateScoringContextRef history(dynamic_cast(context.get())); + OnnxHiddenStateRef newHiddenState; if (not history->hiddenState) { // Sentinel start-state - newHiddenState = updatedHiddenState(computeInitialHiddenState(), request.nextToken); + verify(not history->labelSeq.empty()); + newHiddenState = updatedHiddenState(computeInitialHiddenState(), history->labelSeq.back()); } else { - newHiddenState = updatedHiddenState(history->hiddenState, request.nextToken); + newHiddenState = updatedHiddenState(history->hiddenState, history->labelSeq.back()); } - return Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), newHiddenState)); + auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(history->labelSeq), newHiddenState)); + newScoringContext->requiresFinalize = false; + + return newScoringContext; } void StatefulOnnxLabelScorer::addInput(DataView const& input) { @@ -269,6 +287,9 @@ std::optional StatefulOnnxLabelScorer::computeScor std::unordered_set uniqueUncachedHistories; for (auto& request : requests) { + // The search algorithm is supposed to finalize all scoring contexts before using them for scoring again. + verify(not request.context->requiresFinalize); + OnnxHiddenStateScoringContextRef historyPtr(dynamic_cast(request.context.get())); if (not scoreCache_.contains(historyPtr)) { // Group by unique history diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh index 29f1777c8..7b602f46b 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh @@ -74,8 +74,10 @@ public: // If startLabelIndex is set, forward that through the state updater to obtain the start history Core::Ref getInitialScoringContext() override; - // Forward hidden-state through state-updater ONNX model + // Append the new token to the label sequence Core::Ref extendedScoringContext(LabelScorer::Request const& request) override; + // Forward hidden-state through state-updater ONNX model + Core::Ref finalizeScoringContext(ScoringContextRef const& context) override; // Add a single encoder outputs to buffer void addInput(DataView const& input) override; diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 81e173808..4c0d1a7ed 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -266,6 +266,9 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { auto& hyp = beam_[hypIndex]; + // Finalize scoring context of hypotheses from previous iteration + hyp.scoringContext = labelScorer_->finalizeScoringContext(hyp.scoringContext); + // Iterate over possible successors (all lemmas) for (auto lemmaIt = lemmas.first; lemmaIt != lemmas.second; ++lemmaIt) { const Bliss::Lemma* lemma(*lemmaIt); diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 760b5f1c1..38c915425 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -321,6 +321,9 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { auto& hyp = beam_[hypIndex]; + // Finalize scoring context of hypotheses from previous iteration + hyp.scoringContext = labelScorer_->finalizeScoringContext(hyp.scoringContext); + // Iterate over the successors of this hypothesis' current state in the tree for (const auto& successorState : stateSuccessorLookup_[hyp.currentState]) { Nn::LabelIndex tokenIdx = network_->structure.state(successorState).stateDesc.acousticModel; From 39888eff05317a7929ef503966feaabb280721ac Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 29 Sep 2025 13:28:05 +0200 Subject: [PATCH 2/3] Refactor: Handle finalization internally in LabelScorer --- src/Nn/LabelScorer/CombineLabelScorer.cc | 15 --- src/Nn/LabelScorer/CombineLabelScorer.hh | 1 - .../LabelScorer/EncoderDecoderLabelScorer.cc | 4 - .../LabelScorer/EncoderDecoderLabelScorer.hh | 1 - src/Nn/LabelScorer/LabelScorer.cc | 4 - src/Nn/LabelScorer/LabelScorer.hh | 4 - src/Nn/LabelScorer/ScoringContext.hh | 15 +-- src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc | 107 +++++++++--------- src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh | 16 +-- .../LexiconfreeTimesyncBeamSearch.cc | 3 - .../TreeTimesyncBeamSearch.cc | 3 - 11 files changed, 66 insertions(+), 107 deletions(-) diff --git a/src/Nn/LabelScorer/CombineLabelScorer.cc b/src/Nn/LabelScorer/CombineLabelScorer.cc index 9ebeb28b3..706c6c5cf 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.cc +++ b/src/Nn/LabelScorer/CombineLabelScorer.cc @@ -71,21 +71,6 @@ ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& requ return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); } -ScoringContextRef CombineLabelScorer::finalizeScoringContext(ScoringContextRef const& context) { - auto combineContext = dynamic_cast(context.get()); - - std::vector extScoringContexts; - extScoringContexts.reserve(scaledScorers_.size()); - - auto scorerIt = scaledScorers_.begin(); - auto contextIt = combineContext->scoringContexts.begin(); - - for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) { - extScoringContexts.push_back(scorerIt->scorer->finalizeScoringContext(*contextIt)); - } - return Core::ref(new CombineScoringContext(std::move(extScoringContexts))); -} - void CombineLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { std::vector combineContexts; combineContexts.reserve(activeContexts.internalSize()); diff --git a/src/Nn/LabelScorer/CombineLabelScorer.hh b/src/Nn/LabelScorer/CombineLabelScorer.hh index ef0425d64..a6baf6883 100644 --- a/src/Nn/LabelScorer/CombineLabelScorer.hh +++ b/src/Nn/LabelScorer/CombineLabelScorer.hh @@ -48,7 +48,6 @@ public: // Combine extended ScoringContexts from all sub-scorers ScoringContextRef extendedScoringContext(Request const& request) override; - ScoringContextRef finalizeScoringContext(ScoringContextRef const& context) override; // Cleanup all sub-scorers void cleanupCaches(Core::CollapsedVector const& activeContexts) override; diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc index 0a79c3a34..257c08e1b 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc @@ -37,10 +37,6 @@ ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request cons return decoder_->extendedScoringContext(request); } -ScoringContextRef EncoderDecoderLabelScorer::finalizeScoringContext(ScoringContextRef const& context) { - return decoder_->finalizeScoringContext(context); -} - void EncoderDecoderLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { decoder_->cleanupCaches(activeContexts); } diff --git a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh index 08bc65283..204204aa6 100644 --- a/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh +++ b/src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh @@ -48,7 +48,6 @@ public: // Get extended context from decoder component ScoringContextRef extendedScoringContext(Request const& request) override; - ScoringContextRef finalizeScoringContext(ScoringContextRef const& context) override; // Cleanup decoder component. Encoder is "self-cleaning" already in that it only stores outputs until they are // retrieved. diff --git a/src/Nn/LabelScorer/LabelScorer.cc b/src/Nn/LabelScorer/LabelScorer.cc index 9fd930a3a..010aa56a5 100644 --- a/src/Nn/LabelScorer/LabelScorer.cc +++ b/src/Nn/LabelScorer/LabelScorer.cc @@ -33,10 +33,6 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { } } -ScoringContextRef LabelScorer::finalizeScoringContext(ScoringContextRef const& context) { - return context; -} - std::optional LabelScorer::computeScoresWithTimes(std::vector const& requests) { // By default, just loop over the non-batched `computeScoreWithTime` and collect the results ScoresWithTimes result; diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index d5c21f81f..f7bfe84f5 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -123,10 +123,6 @@ public: // that do not affect the hash (e.g. hidden-state updates). virtual ScoringContextRef extendedScoringContext(Request const& request) = 0; - // Finalize the scoring context by applying remaining expensive operations (e.g. hidden-state updates) - // that don't affect the hash - virtual ScoringContextRef finalizeScoringContext(ScoringContextRef const& context); - // Given a collection of currently active contexts, this function can clean up values in any internal caches // or buffers that are saved for scoring contexts which no longer are active. virtual void cleanupCaches(Core::CollapsedVector const& activeContexts) {}; diff --git a/src/Nn/LabelScorer/ScoringContext.hh b/src/Nn/LabelScorer/ScoringContext.hh index b3f69083a..af2adf10f 100644 --- a/src/Nn/LabelScorer/ScoringContext.hh +++ b/src/Nn/LabelScorer/ScoringContext.hh @@ -30,8 +30,6 @@ static constexpr LabelIndex invalidLabelIndex = Core::Type::max; * Empty scoring context base class */ struct ScoringContext : public Core::ReferenceCounted { - bool requiresFinalize = false; - virtual ~ScoringContext() = default; virtual bool isEqual(Core::Ref const& other) const; @@ -126,7 +124,7 @@ struct OnnxHiddenState : public Core::ReferenceCounted { } }; -typedef Core::Ref OnnxHiddenStateRef; +typedef Core::Ref OnnxHiddenStateRef; /* * Scoring context consisting of a hidden state. @@ -134,16 +132,15 @@ typedef Core::Ref OnnxHiddenStateRef; * from the same label history. */ struct OnnxHiddenStateScoringContext : public ScoringContext { - std::vector labelSeq; // Used for hashing - OnnxHiddenStateRef hiddenState; + std::vector labelSeq; // Used for hashing + mutable OnnxHiddenStateRef hiddenState; + mutable bool requiresFinalize; OnnxHiddenStateScoringContext() - : labelSeq(), hiddenState() {} + : labelSeq(), hiddenState(), requiresFinalize(false) {} OnnxHiddenStateScoringContext(std::vector const& labelSeq, OnnxHiddenStateRef state) - : labelSeq(labelSeq), hiddenState(state) { - requiresFinalize = false; - } + : labelSeq(labelSeq), hiddenState(state), requiresFinalize(false) {} bool isEqual(ScoringContextRef const& other) const; size_t hash() const; diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index 3710b12bb..3e46484b8 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -54,7 +54,7 @@ const Core::ParameterInt StatefulOnnxLabelScorer::paramMaxBatchSize( const Core::ParameterInt StatefulOnnxLabelScorer::paramMaxCachedScores( "max-cached-score-vectors", - "Maximum size of cache that maps histories to scores. This prevents memory overflow in case of very long audio segments.", + "Maximum size of cache that maps scoring contexts to scores. This prevents memory overflow in case of very long audio segments.", 1000); // Scorer only takes hidden states as input which are not part of the IO spec @@ -201,7 +201,7 @@ Core::Ref StatefulOnnxLabelScorer::getInitialScoringContex } Core::Ref StatefulOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { - OnnxHiddenStateScoringContextRef history(dynamic_cast(request.context.get())); + OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); bool updateState = false; switch (request.transitionType) { @@ -224,43 +224,21 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( error() << "Unknown transition type " << request.transitionType; } - // If history is not going to be modified, return the original one + // If scoring context is not going to be modified, return the original one if (not updateState) { return request.context; } - std::vector newLabelSeq(history->labelSeq); + std::vector newLabelSeq(scoringContext->labelSeq); newLabelSeq.push_back(request.nextToken); - auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), history->hiddenState)); + // Re-use previous hidden-state but mark that finalization (i.e. hidden-state update) is required + auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), scoringContext->hiddenState)); newScoringContext->requiresFinalize = true; return newScoringContext; } -Core::Ref StatefulOnnxLabelScorer::finalizeScoringContext(ScoringContextRef const& context) { - // If this scoring context does not need finalization, just return it - if (not context->requiresFinalize) { - return context; - } - - OnnxHiddenStateScoringContextRef history(dynamic_cast(context.get())); - - OnnxHiddenStateRef newHiddenState; - if (not history->hiddenState) { // Sentinel start-state - verify(not history->labelSeq.empty()); - newHiddenState = updatedHiddenState(computeInitialHiddenState(), history->labelSeq.back()); - } - else { - newHiddenState = updatedHiddenState(history->hiddenState, history->labelSeq.back()); - } - - auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(history->labelSeq), newHiddenState)); - newScoringContext->requiresFinalize = false; - - return newScoringContext; -} - void StatefulOnnxLabelScorer::addInput(DataView const& input) { Precursor::addInput(input); @@ -282,44 +260,44 @@ std::optional StatefulOnnxLabelScorer::computeScor result.scores.reserve(requests.size()); /* - * Identify unique histories that still need session runs + * Identify unique scoring contexts that still need session runs */ - std::unordered_set uniqueUncachedHistories; + std::unordered_set uniqueUncachedScoringContexts; for (auto& request : requests) { - // The search algorithm is supposed to finalize all scoring contexts before using them for scoring again. - verify(not request.context->requiresFinalize); + // We need to finalize all scoring contexts before using them for scoring again. - OnnxHiddenStateScoringContextRef historyPtr(dynamic_cast(request.context.get())); - if (not scoreCache_.contains(historyPtr)) { - // Group by unique history - uniqueUncachedHistories.emplace(historyPtr); + OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); + finalizeScoringContext(scoringContext); + if (not scoreCache_.contains(scoringContext)) { + // Group by unique scoring context + uniqueUncachedScoringContexts.emplace(scoringContext); } } - std::vector historyBatch; - historyBatch.reserve(std::min(uniqueUncachedHistories.size(), maxBatchSize_)); - for (auto history : uniqueUncachedHistories) { - historyBatch.push_back(history); - if (historyBatch.size() == maxBatchSize_) { // Batch is full -> forward now - forwardBatch(historyBatch); - historyBatch.clear(); + std::vector scoringContextBatch; + scoringContextBatch.reserve(std::min(uniqueUncachedScoringContexts.size(), maxBatchSize_)); + for (auto scoringContext : uniqueUncachedScoringContexts) { + scoringContextBatch.push_back(scoringContext); + if (scoringContextBatch.size() == maxBatchSize_) { // Batch is full -> forward now + forwardBatch(scoringContextBatch); + scoringContextBatch.clear(); } } - forwardBatch(historyBatch); // Forward remaining histories + forwardBatch(scoringContextBatch); // Forward remaining scoring contexts /* * Assign from cache map to result vector */ for (const auto& request : requests) { - OnnxHiddenStateScoringContextRef history(dynamic_cast(request.context.get())); + OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast(request.context.get())); - verify(scoreCache_.contains(history)); - auto const& scores = scoreCache_.get(history)->get(); + verify(scoreCache_.contains(scoringContext)); + auto const& scores = scoreCache_.get(scoringContext)->get(); result.scores.push_back(scores.at(request.nextToken)); - result.timeframes.push_back(history->labelSeq.size()); + result.timeframes.push_back(scoringContext->labelSeq.size()); } return result; @@ -446,8 +424,25 @@ OnnxHiddenStateRef StatefulOnnxLabelScorer::updatedHiddenState(OnnxHiddenStateRe return newHiddenState; } -void StatefulOnnxLabelScorer::forwardBatch(std::vector const& historyBatch) { - if (historyBatch.empty()) { +void StatefulOnnxLabelScorer::finalizeScoringContext(OnnxHiddenStateScoringContextRef const& scoringContext) { + // If this scoring context does not need finalization, don't change it + if (not scoringContext->requiresFinalize) { + return; + } + + auto hiddenState = scoringContext->hiddenState; + + if (not hiddenState) { // Sentinel start-state + hiddenState = computeInitialHiddenState(); + } + verify(not scoringContext->labelSeq.empty()); + + scoringContext->hiddenState = updatedHiddenState(hiddenState, scoringContext->labelSeq.back()); + scoringContext->requiresFinalize = false; +} + +void StatefulOnnxLabelScorer::forwardBatch(std::vector const& scoringContextBatch) { + if (scoringContextBatch.empty()) { return; } @@ -460,11 +455,11 @@ void StatefulOnnxLabelScorer::forwardBatch(std::vector stateValues; - stateValues.reserve(historyBatch.size()); + stateValues.reserve(scoringContextBatch.size()); - for (size_t b = 0ul; b < historyBatch.size(); ++b) { - auto history = historyBatch[b]; - auto hiddenState = history->hiddenState; + for (size_t b = 0ul; b < scoringContextBatch.size(); ++b) { + auto scoringContext = scoringContextBatch[b]; + auto hiddenState = scoringContext->hiddenState; if (not hiddenState) { // Sentinel hidden-state hiddenState = computeInitialHiddenState(); } @@ -482,10 +477,10 @@ void StatefulOnnxLabelScorer::forwardBatch(std::vector scoreVec; sessionOutputs.front().get(b, scoreVec); - scoreCache_.put(historyBatch[b], std::move(scoreVec)); + scoreCache_.put(scoringContextBatch[b], std::move(scoreVec)); } } diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh index 7b602f46b..0f343ac51 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh @@ -71,13 +71,11 @@ public: void reset() override; - // If startLabelIndex is set, forward that through the state updater to obtain the start history + // If startLabelIndex is set, forward that through the state updater to obtain the start ScoringContext Core::Ref getInitialScoringContext() override; - // Append the new token to the label sequence + // Append the new token to the label sequence; does not update the hidden-state. This is only done once the scoringContext is used for scoring again. Core::Ref extendedScoringContext(LabelScorer::Request const& request) override; - // Forward hidden-state through state-updater ONNX model - Core::Ref finalizeScoringContext(ScoringContextRef const& context) override; // Add a single encoder outputs to buffer void addInput(DataView const& input) override; @@ -89,13 +87,17 @@ protected: size_t getMinActiveInputIndex(Core::CollapsedVector const& activeContexts) const override; private: - // Forward a batch of histories through the ONNX model and put the resulting scores into the score cache - void forwardBatch(std::vector const& historyBatch); + // Forward a batch of scoringContexts through the ONNX model and put the resulting scores into the score cache + void forwardBatch(std::vector const& scoringContextBatch); + // Computes new hidden state based on previous hidden state and next token through state-updater call OnnxHiddenStateRef updatedHiddenState(OnnxHiddenStateRef const& hiddenState, LabelIndex nextToken); + // Replace hidden-state in scoringContext with an updated version that includes the last label + void finalizeScoringContext(OnnxHiddenStateScoringContextRef const& scoringContext); + // Since the hidden-state matrix depends on the encoder time axis, we cannot create properly create hidden-states until all encoder states have been passed. - // So getStartHistory sets the initial hidden-state to a sentinel value (empty Ref) and when other functions such as `extendedHistory` and `getScoresWithTime` + // So getInitialScoringContext sets the initial hidden-state to a sentinel value (empty Ref) and when other functions such as `extendedScoringContext` and `getScoresWithTime` // encounter this sentinel value they call `computeInitialHiddenState` instead to get a usable hidden-state. OnnxHiddenStateRef computeInitialHiddenState(); diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 4c0d1a7ed..81e173808 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -266,9 +266,6 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { auto& hyp = beam_[hypIndex]; - // Finalize scoring context of hypotheses from previous iteration - hyp.scoringContext = labelScorer_->finalizeScoringContext(hyp.scoringContext); - // Iterate over possible successors (all lemmas) for (auto lemmaIt = lemmas.first; lemmaIt != lemmas.second; ++lemmaIt) { const Bliss::Lemma* lemma(*lemmaIt); diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 38c915425..760b5f1c1 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -321,9 +321,6 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { auto& hyp = beam_[hypIndex]; - // Finalize scoring context of hypotheses from previous iteration - hyp.scoringContext = labelScorer_->finalizeScoringContext(hyp.scoringContext); - // Iterate over the successors of this hypothesis' current state in the tree for (const auto& successorState : stateSuccessorLookup_[hyp.currentState]) { Nn::LabelIndex tokenIdx = network_->structure.state(successorState).stateDesc.acousticModel; From ad702b37169582299b5ce116e0442d3176a73c3d Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Mon, 29 Sep 2025 13:28:58 +0200 Subject: [PATCH 3/3] Revert docstring in LabelScorer --- src/Nn/LabelScorer/LabelScorer.hh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index f7bfe84f5..6666d28d4 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -118,9 +118,7 @@ public: // Gets initial scoring context to use for the hypotheses in the first search step virtual ScoringContextRef getInitialScoringContext() = 0; - // Creates a copy of the context in the request that is extended such that the hashing - // and equality operators return the correct result but may omit expensive operations - // that do not affect the hash (e.g. hidden-state updates). + // Creates a copy of the context in the request that is extended using the given token and transition type virtual ScoringContextRef extendedScoringContext(Request const& request) = 0; // Given a collection of currently active contexts, this function can clean up values in any internal caches