diff --git a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
index 33d178f29..cc0b79402 100644
--- a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
+++ b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
@@ -119,6 +119,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore
         case TransitionType::BLANK_TO_LABEL:
         case TransitionType::LABEL_TO_LABEL:
         case TransitionType::INITIAL_LABEL:
+        case TransitionType::SENTENCE_END:
            pushToken     = true;
            timeIncrement = not verticalLabelTransition_;
            break;
diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh
index ed6072c0a..51b51908b 100644
--- a/src/Nn/LabelScorer/LabelScorer.hh
+++ b/src/Nn/LabelScorer/LabelScorer.hh
@@ -84,6 +84,7 @@ public:
         BLANK_LOOP,
         INITIAL_LABEL,
         INITIAL_BLANK,
+        SENTENCE_END,
         numTypes,  // must remain at the end
     };

@@ -153,6 +154,7 @@ protected:
             {"blank-loop", BLANK_LOOP},
             {"initial-label", INITIAL_LABEL},
             {"initial-blank", INITIAL_BLANK},
+            {"sentence-end", SENTENCE_END},
     });
     static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values");
 };
diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc
index be5bc8179..7fe693550 100644
--- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc
+++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc
@@ -218,6 +218,7 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext(
         case LabelScorer::TransitionType::BLANK_TO_LABEL:
         case LabelScorer::TransitionType::LABEL_TO_LABEL:
         case LabelScorer::TransitionType::INITIAL_LABEL:
+        case LabelScorer::TransitionType::SENTENCE_END:
             updateState = true;
             break;
         default:
diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc
index 81e173808..1f5323d81 100644
--- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc
+++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc
@@ -37,7 +37,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis()
         : scoringContext(),
           currentToken(Nn::invalidLabelIndex),
           score(0.0),
-          trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {}
+          trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))),
+          reachedSentenceEnd(false) {}

 LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
         LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base,
@@ -46,13 +47,15 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
         : scoringContext(newScoringContext),
           currentToken(extension.nextToken),
           score(extension.score),
-          trace() {
+          trace(),
+          reachedSentenceEnd(base.reachedSentenceEnd or extension.transitionType == Nn::LabelScorer::SENTENCE_END) {
     switch (extension.transitionType) {
         case Nn::LabelScorer::INITIAL_BLANK:
         case Nn::LabelScorer::INITIAL_LABEL:
         case Nn::LabelScorer::LABEL_TO_LABEL:
         case Nn::LabelScorer::LABEL_TO_BLANK:
         case Nn::LabelScorer::BLANK_TO_LABEL:
+        case Nn::LabelScorer::SENTENCE_END:
             trace = Core::ref(new LatticeTrace(
                     base.trace,
                     extension.pron,
@@ -68,6 +71,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
             trace->score.acoustic = extension.score;
             trace->time           = extension.timeframe + 1;
             break;
+        default:
+            break;
     }
 }

@@ -106,6 +111,21 @@ const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramBlankLabelIndex(
         "Index of the blank label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.",
         Nn::invalidLabelIndex);

+const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramSentenceEndLabelIndex(
+        "sentence-end-label-index",
+        "Index of the sentence-end label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='sentence-end'` or `special='sentence-boundary'`. If not set, the search will not use sentence end.",
+        Nn::invalidLabelIndex);
+
+const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramAllowBlankAfterSentenceEnd(
+        "allow-blank-after-sentence-end",
+        "Blanks can still be produced after the sentence-end has been reached.",
+        true);
+
+const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramSentenceEndFallBack(
+        "sentence-end-fall-back",
+        "Allow for a fallback solution if no hypothesis has reached sentence end at the end of a segment.",
+        true);
+
 const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramCollapseRepeatedLabels(
         "collapse-repeated-labels",
         "Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.",
@@ -126,7 +146,11 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration
         : SearchAlgorithmV2(config),
           maxBeamSize_(paramMaxBeamSize(config)),
           scoreThreshold_(paramScoreThreshold(config)),
+          sentenceEndFallback_(paramSentenceEndFallBack(config)),
           blankLabelIndex_(paramBlankLabelIndex(config)),
+          allowBlankAfterSentenceEnd_(paramAllowBlankAfterSentenceEnd(config)),
+          sentenceEndLemma_(),
+          sentenceEndLabelIndex_(paramSentenceEndLabelIndex(config)),
           collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)),
           logStepwiseStatistics_(paramLogStepwiseStatistics(config)),
           cacheCleanupInterval_(paramCacheCleanupInterval(config)),
@@ -154,6 +178,12 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration
     if (useBlank_) {
         log() << "Use blank label with index " << blankLabelIndex_;
     }
+
+    useSentenceEnd_ = sentenceEndLabelIndex_ != Nn::invalidLabelIndex;
+    if (useSentenceEnd_) {
+        log() << "Use sentence end label with index " << sentenceEndLabelIndex_;
+    }
+
     useScorePruning_ = scoreThreshold_ != Core::Type<Score>::max;
 }

@@ -180,6 +210,21 @@ bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination
         }
     }

+    sentenceEndLemma_ = lexicon_->specialLemma("sentence-end");
+    if (!sentenceEndLemma_) {
+        sentenceEndLemma_ = lexicon_->specialLemma("sentence-boundary");
+    }
+    if (sentenceEndLemma_) {
+        if (sentenceEndLabelIndex_ == Nn::invalidLabelIndex) {
+            sentenceEndLabelIndex_ = sentenceEndLemma_->id();
+            useSentenceEnd_        = true;
+            log() << "Use sentence-end index " << sentenceEndLabelIndex_ << " inferred from lexicon";
+        }
+        else if (sentenceEndLabelIndex_ != static_cast<Nn::LabelIndex>(sentenceEndLemma_->id())) {
+            warning() << "Sentence-end lemma exists in lexicon with id " << sentenceEndLemma_->id() << " but is overridden by config parameter with value " << sentenceEndLabelIndex_;
+        }
+    }
+
     reset();
     return true;
 }
@@ -214,6 +259,7 @@ void LexiconfreeTimesyncBeamSearch::finishSegment() {
     labelScorer_->signalNoMoreFeatures();
     featureProcessingTime_.stop();
     decodeManySteps();
+    finalizeHypotheses();
     logStatistics();
     finishedSegment_ = true;
 }
@@ -271,6 +317,14 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() {
             const Bliss::Lemma* lemma(*lemmaIt);
             Nn::LabelIndex      tokenIdx = lemma->id();

+            // After the first sentence-end token, only allow looping that sentence-end or emitting blanks afterwards
+            if (hyp.reachedSentenceEnd and
+                not(
+                        (collapseRepeatedLabels_ and hyp.currentToken == sentenceEndLabelIndex_ and tokenIdx == sentenceEndLabelIndex_)  // sentence-end loop
+                        or (allowBlankAfterSentenceEnd_ and tokenIdx == blankLabelIndex_))) {  // blank
+                continue;
+            }
+
             auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx);

             extensions_.push_back(
@@ -419,13 +473,17 @@ void LexiconfreeTimesyncBeamSearch::logStatistics() const {
 }

 Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const {
-    bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_);
-    bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_);
+    bool prevIsBlank       = (useBlank_ and prevLabel == blankLabelIndex_);
+    bool nextIsBlank       = (useBlank_ and nextLabel == blankLabelIndex_);
+    bool nextIsSentenceEnd = (useSentenceEnd_ and nextLabel == sentenceEndLabelIndex_);

     if (prevLabel == Nn::invalidLabelIndex) {
         if (nextIsBlank) {
             return Nn::LabelScorer::TransitionType::INITIAL_BLANK;
         }
+        else if (nextIsSentenceEnd) {
+            return Nn::LabelScorer::TransitionType::SENTENCE_END;
+        }
         else {
             return Nn::LabelScorer::TransitionType::INITIAL_LABEL;
         }
@@ -435,6 +493,9 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy
         if (nextIsBlank) {
             return Nn::LabelScorer::TransitionType::BLANK_LOOP;
         }
+        else if (nextIsSentenceEnd) {
+            return Nn::LabelScorer::TransitionType::SENTENCE_END;
+        }
         else {
             return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL;
         }
@@ -446,6 +507,9 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy
     else if (collapseRepeatedLabels_ and prevLabel == nextLabel) {
         return Nn::LabelScorer::TransitionType::LABEL_LOOP;
     }
+    else if (nextIsSentenceEnd) {
+        return Nn::LabelScorer::TransitionType::SENTENCE_END;
+    }
     else {
         return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL;
     }
@@ -513,4 +577,36 @@ void LexiconfreeTimesyncBeamSearch::recombination(std::vector
[...]
+            newBeam_.front().trace->time          = beam_.front().trace->time;  // Retrieve the timeframe from any hyp in the old beam
+            newBeam_.front().trace->pronunciation = nullptr;
+            newBeam_.front().trace->predecessor   = Core::ref(new LatticeTrace(0, {0, 0}, {}));
+            newBeam_.front().reachedSentenceEnd   = true;
+            beam_.swap(newBeam_);
+        }
+    }
+    else {
+        newBeam_.swap(beam_);
+    }
+}
+
 }  // namespace Search
diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh
index 15917a146..490432589 100644
--- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh
+++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh
@@ -43,6 +43,9 @@ public:
     static const Core::ParameterInt   paramMaxBeamSize;
     static const Core::ParameterFloat paramScoreThreshold;
     static const Core::ParameterInt   paramBlankLabelIndex;
+    static const Core::ParameterInt   paramSentenceEndLabelIndex;
+    static const Core::ParameterBool  paramAllowBlankAfterSentenceEnd;
+    static const Core::ParameterBool  paramSentenceEndFallBack;
     static const Core::ParameterBool  paramCollapseRepeatedLabels;
     static const Core::ParameterBool  paramCacheCleanupInterval;
     static const Core::ParameterBool  paramLogStepwiseStatistics;
@@ -83,10 +86,11 @@ protected:
     /*
      * Struct containing all information about a single hypothesis in the beam
      */
     struct LabelHypothesis {
-        Nn::ScoringContextRef   scoringContext;  // Context to compute scores based on this hypothesis
-        Nn::LabelIndex          currentToken;    // Most recent token in associated label sequence (useful to infer transition type)
-        Score                   score;           // Full score of hypothesis
-        Core::Ref<LatticeTrace> trace;           // Associated trace for traceback or lattice building off of hypothesis
+        Nn::ScoringContextRef   scoringContext;      // Context to compute scores based on this hypothesis
+        Nn::LabelIndex          currentToken;        // Most recent token in associated label sequence (useful to infer transition type)
+        Score                   score;               // Full score of hypothesis
+        Core::Ref<LatticeTrace> trace;               // Associated trace for traceback or lattice building off of hypothesis
+        bool                    reachedSentenceEnd;  // Flag whether hypothesis trace contains a sentence end emission

         LabelHypothesis();
         LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext);
@@ -107,8 +111,15 @@ private:
     bool  useScorePruning_;
     Score scoreThreshold_;

+    bool sentenceEndFallback_;
+
     bool           useBlank_;
     Nn::LabelIndex blankLabelIndex_;
+    bool           allowBlankAfterSentenceEnd_;
+
+    bool                useSentenceEnd_;
+    Bliss::Lemma const* sentenceEndLemma_;
+    Nn::LabelIndex      sentenceEndLabelIndex_;

     bool collapseRepeatedLabels_;

@@ -167,6 +178,12 @@ private:
     /*
      * Helper function for recombination of hypotheses with the same scoring context
      */
     void recombination(std::vector<LabelHypothesis>& hypotheses);
+
+    /*
+     * Prune away all hypotheses that have not reached sentence end.
+     * If no hypotheses would survive this, either construct an empty one or keep the beam intact if sentence-end fallback is enabled.
+     */
+    void finalizeHypotheses();
 };

 }  // namespace Search
diff --git a/src/Search/PersistentStateTree.cc b/src/Search/PersistentStateTree.cc
index adcbdfb76..23cbc608f 100644
--- a/src/Search/PersistentStateTree.cc
+++ b/src/Search/PersistentStateTree.cc
@@ -33,7 +33,7 @@ static const Core::ParameterString paramCacheArchive(
         "cache archive in which the persistent state-network should be cached",
         "global-cache");

-static u32 formatVersion = 13;
+static u32 formatVersion = 14;

 namespace Search {
 struct ConvertTree {
@@ -298,7 +298,7 @@ void PersistentStateTree::write(Core::MappedArchiveWriter out) {
     out << coarticulatedRootStates << unpushedCoarticulatedRootStates;
     out << rootTransitDescriptions << pushedWordEndNodes << uncoarticulatedWordEndStates;

-    out << rootState << ciRootState << otherRootStates;
+    out << rootState << ciRootState << otherRootStates << finalStates;
 }

 bool PersistentStateTree::read(Core::MappedArchiveReader in) {
@@ -307,8 +307,8 @@

     /// @todo Eventually do memory-mapping

-    if (v != formatVersion) {
-        Core::Application::us()->log() << "Wrong compressed network format, need " << formatVersion << " got " << v;
+    if (v < 13) {
+        Core::Application::us()->log() << "Wrong compressed network format, need version >= 13 got " << v;
         return false;
     }

@@ -333,6 +333,9 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) {
     in >> pushedWordEndNodes >> uncoarticulatedWordEndStates;

     in >> rootState >> ciRootState >> otherRootStates;
+    if (v >= 14) {
+        in >> finalStates;
+    }

     return in.good();
 }
diff --git a/src/Search/PersistentStateTree.hh b/src/Search/PersistentStateTree.hh
index a200ce9b2..f37bee47e 100644
--- a/src/Search/PersistentStateTree.hh
+++ b/src/Search/PersistentStateTree.hh
@@ -108,6 +108,9 @@ public:
     // Other root nodes (currently used for the wordBoundaryRoot in CtcTreeBuilder)
     std::set<StateId> otherRootStates;

+    // Valid nodes that the search can end in
+    std::set<StateId> finalStates;
+
     // The word-end exits
     std::vector<Exit> exits;

diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc
index d114ec623..131ae7bc1 100644
--- a/src/Search/TreeBuilder.cc
+++ b/src/Search/TreeBuilder.cc
@@ -1222,11 +1222,17 @@ const Core::ParameterBool CtcTreeBuilder::paramForceBlank(
         "require a blank label between two identical labels (only works if label-loops are disabled)",
         true);

+const Core::ParameterBool CtcTreeBuilder::paramAllowBlankAfterSentenceEnd(
+        "allow-blank-after-sentence-end",
+        "Blanks can still be produced after the sentence-end has been reached.",
+        true);
+
 CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
         : AbstractTreeBuilder(config, lexicon, acousticModel, network),
           labelLoop_(paramLabelLoop(config)),
           blankLoop_(paramBlankLoop(config)),
-          forceBlank_(paramForceBlank(config)) {
+          forceBlank_(paramForceBlank(config)),
+          allowBlankAfterSentenceEnd_(paramAllowBlankAfterSentenceEnd(config)) {
     auto iters = lexicon.phonemeInventory()->phonemes();
     for (auto it = iters.first; it != iters.second; ++it) {
         require(not(*it)->isContextDependent());  // Context dependent labels are not supported
@@ -1248,6 +1254,26 @@ CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon&
             wordBoundaryRoot_ = createRoot();
             network_.otherRootStates.insert(wordBoundaryRoot_);
         }
+
+        // Create a special root for sentence-end
+        auto sentenceEndLemma = getSentenceEndLemma();
+        if (sentenceEndLemma == nullptr or sentenceEndLemma->nPronunciations() == 0) {
+            if (sentenceEndLemma != nullptr) {
+                warning() << "Building tree without sentence-end, which means it may also not be scored by the LM";
+            }
+
+            // If no sentence-end is present, any root state is a valid final state
+            network_.finalStates.insert(network_.rootState);
+            for (auto const& otherRootState : network_.otherRootStates) {
+                network_.finalStates.insert(otherRootState);
+            }
+        }
+        else {
+            // If sentence-end is present, the sink state is the only valid final state
+            sentenceEndSink_ = createRoot();
+            network_.otherRootStates.insert(sentenceEndSink_);
+            network_.finalStates.insert(sentenceEndSink_);
+        }
     }
 }

@@ -1261,15 +1287,19 @@ void CtcTreeBuilder::build() {
         addWordBoundaryStates();
     }

+    auto sentenceEndLemma = getSentenceEndLemma();
+    if (sentenceEndLemma != nullptr and sentenceEndLemma->nPronunciations() != 0) {
+        addSentenceEndStates();
+    }
+
     auto blankLemma   = lexicon_.specialLemma("blank");
     auto silenceLemma = lexicon_.specialLemma("silence");
     auto iters        = lexicon_.lemmaPronunciations();

     // Iterate over the lemmata and add them to the tree
     for (auto it = iters.first; it != iters.second; ++it) {
-        if ((*it)->lemma() == wordBoundaryLemma) {
-            // The wordBoundaryLemma should be a successor of the wordBoundaryRoot_
-            // This is handled separately in addWordBoundaryStates()
+        if ((*it)->lemma() == wordBoundaryLemma or (*it)->lemma() == sentenceEndLemma) {
+            // Word-boundary and sentence-end lemmas are handled separately by `addWordBoundaryStates` and `addSentenceEndStates`
             continue;
         }

@@ -1439,6 +1469,36 @@ void CtcTreeBuilder::addWordBoundaryStates() {
     }
 }

+void CtcTreeBuilder::addSentenceEndStates() {
+    auto sentenceEndLemma = getSentenceEndLemma();
+    if (sentenceEndLemma == nullptr) {
+        return;
+    }
+
+    // Add the sentence-end to the tree, starting from the root.
+    require(sentenceEndLemma->nPronunciations() == 1);  // Sentence-end must have exactly one pronunciation, which may be empty.
+    auto const& sentenceEndPron = *sentenceEndLemma->pronunciations().first;
+    // It may be that sentenceEndLastState == root if the pronunciation has length 0.
+    StateId sentenceEndLastState = extendPronunciation(network_.rootState, sentenceEndPron.pronunciation());
+    verify(sentenceEndLastState != 0);
+
+    addExit(sentenceEndLastState, sentenceEndSink_, sentenceEndPron.id());
+
+    // Add optional blank after the sentence-end lemma
+    if (allowBlankAfterSentenceEnd_) {
+        StateId blankAfter = extendState(sentenceEndSink_, blankDesc_);
+        addExit(blankAfter, sentenceEndSink_, lexicon_.specialLemma("blank")->id());
+    }
+}
+
+Bliss::Lemma const* CtcTreeBuilder::getSentenceEndLemma() const {
+    auto sentenceEndLemma = lexicon_.specialLemma("sentence-end");
+    if (sentenceEndLemma == nullptr) {
+        sentenceEndLemma = lexicon_.specialLemma("sentence-boundary");
+    }
+    return sentenceEndLemma;
+}
+
 // -------------------- RnaTreeBuilder --------------------

 const Core::ParameterBool RnaTreeBuilder::paramLabelLoop(
diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh
index 78b0c829d..6c00e1b8b 100644
--- a/src/Search/TreeBuilder.hh
+++ b/src/Search/TreeBuilder.hh
@@ -252,6 +252,7 @@ public:
     static const Core::ParameterBool paramLabelLoop;
     static const Core::ParameterBool paramBlankLoop;
     static const Core::ParameterBool paramForceBlank;
+    static const Core::ParameterBool paramAllowBlankAfterSentenceEnd;

     CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);
     virtual ~CtcTreeBuilder() = default;
@@ -265,8 +266,10 @@ protected:
     bool labelLoop_;
     bool blankLoop_;
     bool forceBlank_;
+    bool allowBlankAfterSentenceEnd_;

     StateId wordBoundaryRoot_;
+    StateId sentenceEndSink_;  // Reached after emitting sentence-end; has no outgoing transitions except a blank loop if `allowBlankAfterSentenceEnd_` is enabled

     Search::StateTree::StateDesc blankDesc_;
     Am::AllophoneStateIndex      blankAllophoneStateIndex_;
@@ -291,6 +294,11 @@ protected:

     // Build the sub-tree with the word-boundary lemma plus optional blank starting from `wordBoundaryRoot_`.
     void addWordBoundaryStates();
+
+    // Build the sub-tree with the sentence-end lemma plus optional blank, ending in `sentenceEndSink_`.
+    void addSentenceEndStates();
+
+    Bliss::Lemma const* getSentenceEndLemma() const;
 };

 class RnaTreeBuilder : public CtcTreeBuilder {
diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc
index 760b5f1c1..29cead38a 100644
--- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc
+++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc
@@ -18,13 +18,14 @@
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
-#include "Search/Module.hh"
-#include "Search/Traceback.hh"
+#include <Search/Module.hh>
+#include <Search/Traceback.hh>

 namespace Search {

@@ -137,6 +138,8 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config
           maxWordEndBeamSize_(paramMaxWordEndBeamSize(config)),
           scoreThreshold_(paramScoreThreshold(config)),
           wordEndScoreThreshold_(paramWordEndScoreThreshold(config)),
+          blankLabelIndex_(Nn::invalidLabelIndex),
+          sentenceEndLabelIndex_(Nn::invalidLabelIndex),
           cacheCleanupInterval_(paramCacheCleanupInterval(config)),
           useBlank_(),
           collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)),
@@ -201,6 +204,7 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const&

     std::unique_ptr<AbstractTreeBuilder> builder = Search::Module::instance().createTreeBuilder(config, *lexicon_, *acousticModel_, *network_);
     builder->build();
+    network_->dumpDotGraph("tree.dot");

     if (lexicon_->specialLemma("blank")) {
         blankLabelIndex_ = acousticModel_->emissionIndex(acousticModel_->blankAllophoneStateIndex());
@@ -212,6 +216,23 @@ bool TreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const&
         useBlank_ = false;
     }

+    auto const* sentenceEndLemma = lexicon_->specialLemma("sentence-end");
+    if (not sentenceEndLemma) {
+        sentenceEndLemma = lexicon_->specialLemma("sentence-boundary");
+    }
+    if (sentenceEndLemma and sentenceEndLemma->nPronunciations() != 0) {
+        auto const* pron = sentenceEndLemma->pronunciations().first->pronunciation();
+        require(pron->length() == 1);
+        Am::Allophone allo(acousticModel_->phonology()->allophone(*pron, 0),
+                           Am::Allophone::isInitialPhone | Am::Allophone::isFinalPhone);
+        Am::AllophoneStateIndex alloStateIdx = acousticModel_->allophoneStateAlphabet()->index(&allo, 0);
+
+        sentenceEndLabelIndex_ = acousticModel_->emissionIndex(alloStateIdx);
+    }
+    else {
+        sentenceEndLabelIndex_ = Nn::invalidLabelIndex;
+    }
+
     for (const auto& lemma : {"silence", "blank"}) {
         if (lexicon_->specialLemma(lemma) and (lexicon_->specialLemma(lemma)->syntacticTokenSequence()).size() != 0) {
             warning("Special lemma \"%s\" will be scored by the language model. To prevent the LM from scoring it, set an empty syntactic token sequence for it in the lexicon.", lemma);
@@ -273,7 +294,7 @@ void TreeTimesyncBeamSearch::finishSegment() {
     decodeManySteps();
     logStatistics();
     finishedSegment_ = true;
-    finalizeLmScoring();
+    finalizeHypotheses();
 }

 void TreeTimesyncBeamSearch::putFeature(Nn::DataView const& feature) {
@@ -478,6 +499,47 @@ bool TreeTimesyncBeamSearch::decodeStep() {
         clog() << Core::XmlFull("num-word-end-hyps-after-beam-pruning", wordEndHypotheses_.size());
     }

+    /*
+     * Account for two exits back-to-back on the word-end hypotheses (usually a sentence-end with a zero-length pronunciation directly after a word end)
+     */
+    auto const origSize = wordEndHypotheses_.size();
+    for (size_t hypIndex = 0ul; hypIndex < origSize; ++hypIndex) {
+        auto hyp = wordEndHypotheses_[hypIndex];  // Copy, since the push_back below may reallocate the vector and invalidate references

+        auto exitList = exitLookup_[hyp.currentState];
+        // Create one word-end hypothesis for each exit
+        for (const auto& exit : exitList) {
+            auto const* lemmaPron = lexicon_->lemmaPronunciation(exit.pronunciation);
+            auto const* lemma     = lemmaPron->lemma();
+
+            ExtensionCandidate wordEndExtension{hyp.currentToken,
+                                                lemmaPron,
+                                                exit.transitState,  // Start from the root node (the exit's transit state) in the next step
+                                                hyp.lmHistory,
+                                                hyp.score,
+                                                0.0,
+                                                static_cast<TimeframeIndex>(currentSearchStep_),
+                                                Nn::LabelScorer::TransitionType::INITIAL_BLANK,  // The transition type is irrelevant here, so just use this as a dummy
+                                                hypIndex};
+
+            auto const sts = lemma->syntacticTokenSequence();
+            if (sts.size() != 0) {
+                require(sts.size() == 1);
+                auto const* st = sts.front();
+
+                // Add the LM score
+                Lm::Score lmScore = languageModel_->score(wordEndExtension.lmHistory, st);
+                wordEndExtension.score += lmScore;
+                wordEndExtension.lmScore = lmScore;
+
+                // Extend the LM history
+                wordEndExtension.lmHistory = languageModel_->extendedHistory(wordEndExtension.lmHistory, st);
+            }
+            wordEndHypotheses_.push_back({hyp, wordEndExtension, hyp.scoringContext});
+        }
+    }
+    recombination(wordEndHypotheses_);
+
     beam_.swap(newBeam_);
     beam_.insert(beam_.end(), wordEndHypotheses_.begin(), wordEndHypotheses_.end());

@@ -570,13 +632,17 @@ void TreeTimesyncBeamSearch::logStatistics() const {
 }

 Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const {
-    bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_);
-    bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_);
+    bool prevIsBlank       = (useBlank_ and prevLabel == blankLabelIndex_);
+    bool nextIsBlank       = (useBlank_ and nextLabel == blankLabelIndex_);
+    bool nextIsSentenceEnd = nextLabel == sentenceEndLabelIndex_;

     if (prevLabel == Nn::invalidLabelIndex) {
         if (nextIsBlank) {
             return Nn::LabelScorer::TransitionType::INITIAL_BLANK;
         }
+        else if (nextIsSentenceEnd) {
+            return Nn::LabelScorer::TransitionType::SENTENCE_END;
+        }
         else {
             return Nn::LabelScorer::TransitionType::INITIAL_LABEL;
         }
@@ -586,6 +652,9 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::
         if (nextIsBlank) {
             return Nn::LabelScorer::TransitionType::BLANK_LOOP;
         }
+        else if (nextIsSentenceEnd) {
+            return Nn::LabelScorer::TransitionType::SENTENCE_END;
+        }
         else {
             return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL;
         }
@@ -597,6 +666,9 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::
     else if (collapseRepeatedLabels_ and prevLabel == nextLabel) {
         return Nn::LabelScorer::TransitionType::LABEL_LOOP;
     }
+    else if (nextIsSentenceEnd) {
+        return Nn::LabelScorer::TransitionType::SENTENCE_END;
+    }
     else {
         return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL;
     }
@@ -717,30 +789,24 @@ void TreeTimesyncBeamSearch::createSuccessorLookups() {
     }
 }

-void TreeTimesyncBeamSearch::finalizeLmScoring() {
+void TreeTimesyncBeamSearch::finalizeHypotheses() {
     newBeam_.clear();
-    for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) {
-        auto& hyp = beam_[hypIndex];
-        // Check if the hypotheses in the beam are at a root state and add the sentence-end LM score
-        if (hyp.currentState == network_->rootState or network_->otherRootStates.find(hyp.currentState) != network_->otherRootStates.end()) {
-            Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory);
-            hyp.score += sentenceEndScore;
-            hyp.trace->score.lm += sentenceEndScore;
+    for (auto const& hyp : beam_) {
+        if (network_->finalStates.contains(hyp.currentState)) {
             newBeam_.push_back(hyp);
         }
     }
-    if (newBeam_.empty()) {  // There was no word-end hypothesis in the beam
+    if (newBeam_.empty()) {  // There was no valid final hypothesis in the beam
         warning("No active word-end hypothesis at segment end.");
         if (sentenceEndFallback_) {
             log() << "Use sentence-end fallback";
             // The trace of the unfinished word keeps an empty pronunciation, only the LM score is added
-            for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) {
-                auto& hyp = beam_[hypIndex];
-                Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory);
-                hyp.score += sentenceEndScore;
-                hyp.trace->score.lm += sentenceEndScore;
+            for (auto const& hyp : beam_) {
                 newBeam_.push_back(hyp);
+                Lm::Score sentenceEndScore = languageModel_->sentenceEndScore(hyp.lmHistory);
+                newBeam_.back().score += sentenceEndScore;
+                newBeam_.back().trace->score.lm += sentenceEndScore;
             }
         }
         else {
diff --git a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh
index 09672d81e..8c4e63378 100644
--- a/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh
+++ b/src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh
@@ -118,6 +118,7 @@ private:
     Score          scoreThreshold_;
     Score          wordEndScoreThreshold_;
     Nn::LabelIndex blankLabelIndex_;
+    Nn::LabelIndex sentenceEndLabelIndex_;
     size_t         cacheCleanupInterval_;

     bool useBlank_;
@@ -198,10 +199,10 @@ private:
     /*
      * After reaching the segment end, go through the active hypotheses, only keep those
-     * which are at a word end (in the root state) and add the sentence end LM score.
-     * If no word-end hypotheses exist, use sentence-end fallback or construct an empty hypothesis
+     * which are in final states of the search tree.
+     * If no such hypotheses exist, use sentence-end fallback or construct an empty hypothesis.
      */
-    void finalizeLmScoring();
+    void finalizeHypotheses();
 };

 }  // namespace Search