diff --git a/src/Search/LanguageModelLookahead.cc b/src/Search/LanguageModelLookahead.cc index 5044991d6..23edda2ac 100644 --- a/src/Search/LanguageModelLookahead.cc +++ b/src/Search/LanguageModelLookahead.cc @@ -649,6 +649,8 @@ void LanguageModelLookahead::ConstructionTree::build(HMMStateNetwork const& for (HMMStateNetwork::SuccessorIterator target = tree_.successors(node); target; ++target) { if (not target.isLabel()) { + if (*target == node) + continue; build(*target, depth + 1); successors.push_back(*target); } @@ -743,7 +745,7 @@ void LanguageModelLookahead::ConstructionTree::build(HMMStateNetwork const& collected[node] = -2; for (HMMStateNetwork::SuccessorIterator edges = tree_.successors(node); edges; ++edges) { - if (not edges.isLabel()) { + if (not edges.isLabel() and *edges != node) { int depth2 = collectTopologicalStates(*edges, depth + 1, topologicalStates, collected); if (depth2 - 1 < depth) { depth = depth2 - 1; diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc index 240d63b51..610c2ccdd 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include "Search/Module.hh" @@ -38,8 +40,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() : scoringContext(), currentToken(Nn::invalidLabelIndex), currentState(invalidTreeNodeIndex), + lookahead(), lmHistory(), + lookaheadHistory(), + fullLookaheadHistory(), score(0.0), + lookaheadScore(0.0), trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( @@ -49,8 +55,12 @@ TreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( : scoringContext(newScoringContext), currentToken(extension.nextToken), currentState(extension.state), + lookahead(extension.lookahead), lmHistory(extension.lmHistory), + 
lookaheadHistory(extension.lookaheadHistory), + fullLookaheadHistory(extension.fullLookaheadHistory), score(extension.score), + lookaheadScore(extension.lmScore), trace(base.trace) { if (extension.pron != nullptr) { // Word-end hypothesis -> update base trace and start a new trace for the next word auto completedTrace = Core::ref(new LatticeTrace(*base.trace)); @@ -115,6 +125,21 @@ const Core::ParameterBool TreeTimesyncBeamSearch::paramCollapseRepeatedLabels( "Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.", false); +const Core::ParameterBool TreeTimesyncBeamSearch::paramLmLookahead( + "lm-lookahead", + "Enable language model lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSeparateLookaheadLm( + "separate-lookahead-lm", + "Use a separate LM for lookahead.", + false); + +const Core::ParameterBool TreeTimesyncBeamSearch::paramSparseLmLookAhead( + "sparse-lm-lookahead", + "Use sparse n-gram LM lookahead.", + true); + const Core::ParameterBool TreeTimesyncBeamSearch::paramSentenceEndFallBack( "sentence-end-fall-back", "Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.", @@ -140,6 +165,9 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config cacheCleanupInterval_(paramCacheCleanupInterval(config)), useBlank_(), collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), + enableLmLookahead_(paramLmLookahead(config)), + separateLookaheadLm_(paramSeparateLookaheadLm(config)), + sparseLmLookahead_(paramSparseLmLookAhead(config)), sentenceEndFallback_(paramSentenceEndFallBack(config)), logStepwiseStatistics_(paramLogStepwiseStatistics(config)), labelScorer_(), @@ -215,6 +243,34 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& // Create look-ups for state successors and exits of each state createSuccessorLookups(); + // Set lookahead LM + if (enableLmLookahead_) { + 
if (separateLookaheadLm_) { + log() << "Use separate lookahead LM"; + lookaheadLm_ = Lm::Module::instance().createScaledLanguageModel(select("lm-lookahead"), lexicon_); + } + else if (languageModel_->lookaheadLanguageModel().get() != nullptr) { + lookaheadLm_ = Core::Ref(new Lm::LanguageModelScaling(select("lookahead-lm"), + Core::Ref(const_cast(languageModel_->lookaheadLanguageModel().get())))); + } + else { + lookaheadLm_ = languageModel_; + } + + if (sparseLmLookahead_ && !dynamic_cast(lookaheadLm_->unscaled().get())) { + warning() << "Not using sparse LM lookahead, because the LM is not a backing-off LM."; + sparseLmLookahead_ = false; + } + + lmLookahead_ = new LanguageModelLookahead(Core::Configuration(config, "lm-lookahead"), + modelCombination.pronunciationScale(), + lookaheadLm_, + network_->structure, + network_->rootState, + network_->exits, + acousticModel_); + } + reset(); // Create global cache @@ -240,6 +296,11 @@ void TreeTimesyncBeamSearch::reset() { beam_.front().currentState = network_->rootState; beam_.front().lmHistory = languageModel_->startHistory(); + if (enableLmLookahead_) { + beam_.front().lookaheadHistory = lookaheadLm_->startHistory(); + beam_.front().fullLookaheadHistory = lookaheadLm_->startHistory(); + } + currentSearchStep_ = 0ul; finishedSegment_ = false; @@ -330,7 +391,10 @@ bool TreeTimesyncBeamSearch::decodeStep() { {tokenIdx, nullptr, successorState, + hyp.lookahead, hyp.lmHistory, + hyp.lookaheadHistory, + hyp.fullLookaheadHistory, hyp.score, 0.0, 0, @@ -355,6 +419,14 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t requestIdx = 0ul; requestIdx < extensions_.size(); ++requestIdx) { extensions_[requestIdx].score += result->scores[requestIdx]; extensions_[requestIdx].timeframe = result->timeframes[requestIdx]; + + // Add the LM lookahead score to the extensions' scores for pruning + // Make sure not to calculate the lookahead score for the blank lemma which is reachable from the root + if (enableLmLookahead_ and 
not(beam_[extensions_[requestIdx].baseHypIndex].currentState == network_->rootState and extensions_[requestIdx].nextToken == blankLabelIndex_)) { + Score lookaheadScore = getLmLookaheadScore(extensions_[requestIdx]); + extensions_[requestIdx].lmScore = lookaheadScore; + extensions_[requestIdx].score += lookaheadScore; + } } if (logStepwiseStatistics_) { @@ -404,6 +476,12 @@ bool TreeTimesyncBeamSearch::decodeStep() { for (size_t hypIndex = 0ul; hypIndex < newBeam_.size(); ++hypIndex) { auto& hyp = newBeam_[hypIndex]; + if (enableLmLookahead_) { + // Subtract the LM lookahead score again + hyp.score -= hyp.lookaheadScore; + hyp.lookaheadScore = 0; + } + std::vector exitList = exitLookup_[hyp.currentState]; if (not exitList.empty()) { // Create one word-end hypothesis for each exit @@ -414,7 +492,10 @@ bool TreeTimesyncBeamSearch::decodeStep() { ExtensionCandidate wordEndExtension{hyp.currentToken, lemmaPron, exit.transitState, // Start from the root node (the exit's transit state) in the next step + hyp.lookahead, hyp.lmHistory, + hyp.lookaheadHistory, + hyp.fullLookaheadHistory, hyp.score, 0.0, static_cast(currentSearchStep_), @@ -444,7 +525,7 @@ bool TreeTimesyncBeamSearch::decodeStep() { clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", extensions_.size()); } - // Create new word-end label hypotheses from word-end extension candidates and update the LM history + // Create new word-end label hypotheses from word-end extension candidates, update the LM history and prepare the new lookahead if its history has changed wordEndHypotheses_.clear(); for (auto& extension : extensions_) { const Bliss::Lemma* lemma = extension.pron->lemma(); @@ -452,6 +533,16 @@ bool TreeTimesyncBeamSearch::decodeStep() { const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence(); const Bliss::SyntacticToken* st = sts.front(); extension.lmHistory = languageModel_->extendedHistory(extension.lmHistory, st); + + if (enableLmLookahead_) { + Lm::History 
newLookaheadHistory = lookaheadLm_->extendedHistory(extension.fullLookaheadHistory, st); + + if (!(newLookaheadHistory == extension.lookaheadHistory)) { + getLmLookahead(extension.lookahead, newLookaheadHistory); + extension.lookaheadHistory = newLookaheadHistory; + extension.fullLookaheadHistory = newLookaheadHistory; + } + } } auto const& baseHyp = newBeam_[extension.baseHypIndex]; @@ -689,6 +780,41 @@ void TreeTimesyncBeamSearch::recombination(std::vectorgetLookahead(history); + lmLookahead_->fill(lookahead, sparseLmLookahead_); +} + +Score TreeTimesyncBeamSearch::getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension) { + if (!extension.lookahead) { + getLmLookahead(extension.lookahead, extension.lookaheadHistory); + } + + Score lookaheadScore = 0; + bool scoreFound = false; + do { + if (extension.lookahead->isSparse()) { // Sparse lookahead + auto lookaheadHash = lmLookahead_->lookaheadHash(extension.state); + scoreFound = extension.lookahead->getScoreForLookAheadHashSparse(lookaheadHash, lookaheadScore); + } + else { // Non-sparse lookahead + auto lookaheadId = lmLookahead_->lookaheadId(extension.state); + lookaheadScore = extension.lookahead->scoreForLookAheadIdNormal(lookaheadId); + scoreFound = true; + } + + if (!scoreFound) { // No lookahead table entry, use back-off + const Lm::BackingOffLm* lm = dynamic_cast(lookaheadLm_->unscaled().get()); + lookaheadScore += lm->getBackOffScore(extension.lookaheadHistory); + // Reduce the history and retrieve the corresponding lookahead table + extension.lookaheadHistory = lm->reducedHistory(extension.lookaheadHistory, lm->historyLength(extension.lookaheadHistory) - 1); + getLmLookahead(extension.lookahead, extension.lookaheadHistory); + } + } while (!scoreFound); + + return lookaheadScore; +} + void TreeTimesyncBeamSearch::createSuccessorLookups() { stateSuccessorLookup_.resize(network_->structure.stateCount()); exitLookup_.resize(network_->structure.stateCount()); @@ -746,4 +872,4 @@ void 
TreeTimesyncBeamSearch::finalizeLmScoring() { beam_.swap(newBeam_); } -} // namespace Search +} // namespace Search \ No newline at end of file diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh index 09672d81e..c7b14afc2 100644 --- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh +++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh @@ -26,6 +26,7 @@ #include #include #include +#include "Search/LanguageModelLookahead.hh" namespace Search { @@ -33,6 +34,7 @@ namespace Search { * Simple time synchronous beam search algorithm on a search tree built by a TreeBuilder. * At a word end, a language model score is added to the hypothesis score, * if no language model should be used, the LM-scale has to be set to 0.0. + * Full or sparse language model lookahead can optionally be used with the same or with a separate LM. * Supports global or separate pruning of within-word and word-end hypotheses * by max beam-size and by score difference to the best hypothesis. * Uses a LabelScorer to context initialization/extension and scoring. 
@@ -48,6 +50,9 @@ public: static const Core::ParameterFloat paramScoreThreshold; static const Core::ParameterFloat paramWordEndScoreThreshold; static const Core::ParameterBool paramCollapseRepeatedLabels; + static const Core::ParameterBool paramLmLookahead; + static const Core::ParameterBool paramSeparateLookaheadLm; + static const Core::ParameterBool paramSparseLmLookAhead; static const Core::ParameterBool paramSentenceEndFallBack; static const Core::ParameterBool paramLogStepwiseStatistics; static const Core::ParameterBool paramCacheCleanupInterval; @@ -73,15 +78,18 @@ protected: * Possible extension for some label hypothesis in the beam */ struct ExtensionCandidate { - Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with - const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end - StateId state; // State in the search tree of this extension - Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end - Score score; // Would-be total score of the full hypothesis after extension (incl. 
LM score) - Score lmScore; // Would-be LM score of a word-end hypothesis after extension - Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback - Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` - size_t baseHypIndex; // Index of base hypothesis in global beam + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of the lemma if we are at a word end + StateId state; // State in the search tree of this extension + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table, possibly updated at a word end + Lm::History lmHistory; // LM history of the hypothesis, possibly extended at a word end + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead which will be expanded at a word end + Score score; // Would-be total score of the full hypothesis after extension (incl. 
LM score) + Score lmScore; // Would-be LM score of a word-end hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam bool operator<(ExtensionCandidate const& other) { return score < other.score; @@ -92,12 +100,16 @@ protected: * Struct containing all information about a single hypothesis in the beam */ struct LabelHypothesis { - Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis - Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) - StateId currentState; // Current state in the search tree - Lm::History lmHistory; // Language model history - Score score; // Full score of the hypothesis - Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + StateId currentState; // Current state in the search tree + LanguageModelLookahead::ContextLookaheadReference lookahead; // LM-lookahead table + Lm::History lmHistory; // Language model history + Lm::History lookaheadHistory; // LM history used for the lookahead, may be reduced + Lm::History fullLookaheadHistory; // The full/unreduced LM history for the lookahead + Score score; // Full score of the hypothesis + Score lookaheadScore; // LM-lookahead score + Core::Ref trace; // Associated trace for traceback or lattice building of hypothesis LabelHypothesis(); LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); @@ -130,8 +142,14 @@ private: Core::Ref network_; Core::Ref acousticModel_; Core::Ref languageModel_; 
+ Core::Ref lookaheadLm_; Core::Channel debugChannel_; + bool enableLmLookahead_; + bool separateLookaheadLm_; + bool sparseLmLookahead_; + LanguageModelLookahead* lmLookahead_; + // Pre-allocated intermediate vectors std::vector extensions_; std::vector beam_; @@ -187,6 +205,16 @@ private: */ void recombination(std::vector& hypotheses); + /* + * Retrieve the LM lookahead for the given history from cache or compute and cache it if missing + */ + void getLmLookahead(LanguageModelLookahead::ContextLookaheadReference& lookahead, Lm::History history); + + /* + * Compute the sparse or non-sparse LM lookahead score for an extension's state and history, with back-off if needed + */ + Score getLmLookaheadScore(TreeTimesyncBeamSearch::ExtensionCandidate& extension); + /* * Precompute information about the successor structure of each state in the search tree * to avoid repeated computation during the decode steps @@ -206,4 +234,4 @@ private: } // namespace Search -#endif // TREE_TIMESYNC_BEAM_SEARCH_HH +#endif // TREE_TIMESYNC_BEAM_SEARCH_HH \ No newline at end of file