@@ -145,10 +145,9 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config
145145 labelScorer_(),
146146 debugChannel_(config, " debug" ),
147147 extensions_(),
148- withinWordExtensions_(),
149- wordEndExtensions_(),
150148 beam_(),
151149 newBeam_(),
150+ wordEndHypotheses_(),
152151 requests_(),
153152 recombinedHypotheses_(),
154153 currentSearchStep_(0ul ),
@@ -158,10 +157,13 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config
158157 scoringTime_(),
159158 contextExtensionTime_(),
160159 numHypsAfterScorePruning_(" num-hyps-after-score-pruning" ),
160+ numHypsAfterRecombination_(" num-hyps-after-recombination" ),
161161 numHypsAfterBeamPruning_(" num-hyps-after-beam-pruning" ),
162162 numWordEndHypsAfterScorePruning_(" num-word-end-hyps-after-score-pruning" ),
163+ numWordEndHypsAfterRecombination_(" num-word-end-hyps-after-recombination" ),
163164 numWordEndHypsAfterBeamPruning_(" num-word-end-hyps-after-beam-pruning" ),
164- numActiveHyps_(" num-active-hyps" ) {
165+ numActiveHyps_(" num-active-hyps" ),
166+ numActiveTrees_(" num-active-trees" ) {
165167 if (scoreThreshold_ == Core::Type<Score>::max and wordEndScoreThreshold_ != Core::Type<Score>::max) {
166168 error () << " Word-end score-threshold which is relative to the score-threshold is set, but score-threshold is not set" ;
167169 }
@@ -359,111 +361,133 @@ bool TreeTimesyncBeamSearch::decodeStep() {
359361 clog () << Core::XmlFull (" num-hyps-after-score-pruning" , extensions_.size ());
360362 }
361363
362- beamSizePruning (extensions_, maxBeamSize_);
363- numHypsAfterBeamPruning_ += extensions_.size ();
364+ // Create new label hypotheses from extension candidates
365+ newBeam_.clear ();
366+ for (auto const & extension : extensions_) {
367+ auto const & baseHyp = beam_[extension.baseHypIndex ];
368+
369+ auto newScoringContext = labelScorer_->extendedScoringContext (
370+ {baseHyp.scoringContext ,
371+ extension.nextToken ,
372+ extension.transitionType });
373+
374+ newBeam_.push_back ({baseHyp, extension, newScoringContext});
375+ }
376+
377+ // For all hypotheses at the same state and with the same scoring context and LM history
378+ // keep only the best since they will all develop in the same way
379+ recombination (newBeam_);
380+ numHypsAfterRecombination_ += newBeam_.size ();
381+ if (logStepwiseStatistics_) {
382+ clog () << Core::XmlFull (" num-hyps-after-recombination" , newBeam_.size ());
383+ }
384+
385+ beamSizePruning (newBeam_, maxBeamSize_);
386+ numHypsAfterBeamPruning_ += newBeam_.size ();
364387 if (logStepwiseStatistics_) {
365- clog () << Core::XmlFull (" num-hyps-after-beam-pruning" , extensions_ .size ());
388+ clog () << Core::XmlFull (" num-hyps-after-beam-pruning" , newBeam_ .size ());
366389 }
367390
368391 /*
369- * Expand extensions to word-end hypotheses and incorporate the language model
392+ * Expand hypotheses to word-end hypotheses and incorporate the language model
370393 */
371- withinWordExtensions_.clear ();
372- wordEndExtensions_.clear ();
373- for (const auto & extension : extensions_) {
374- // If there is at least one state successor, keep it as within-word hypothesis
375- if (not stateSuccessorLookup_[extension.state ].empty ()) {
376- withinWordExtensions_.push_back (extension);
377- }
378- std::vector<PersistentStateTree::Exit> exitList = exitLookup_[extension.state ];
394+ extensions_.clear ();
395+ for (size_t hypIndex = 0ul ; hypIndex < newBeam_.size (); ++hypIndex) {
396+ auto & hyp = newBeam_[hypIndex];
397+
398+ std::vector<PersistentStateTree::Exit> exitList = exitLookup_[hyp.currentState ];
379399 if (not exitList.empty ()) {
380400 // Create one word-end hypothesis for each exit
381401 for (const auto & exit : exitList) {
382- ExtensionCandidate wordEndExtension (extension);
383402 const Bliss::LemmaPronunciation* lemmaPron = lexicon_->lemmaPronunciation (exit.pronunciation );
384403 const Bliss::Lemma* lemma = lemmaPron->lemma ();
385404
386- // Start from the root node (the exit's transit state) in the next step
387- wordEndExtension.state = exit.transitState ;
388- wordEndExtension.pron = lemmaPron;
405+ ExtensionCandidate wordEndExtension{hyp.currentToken ,
406+ lemmaPron,
407+ exit.transitState , // Start from the root node (the exit's transit state) in the next step
408+ hyp.lmHistory ,
409+ hyp.score ,
410+ 0.0 ,
411+ static_cast <TimeframeIndex>(currentSearchStep_),
412+ Nn::LabelScorer::TransitionType::INITIAL_BLANK, // The transition type is irrelevant, so just use this as dummy
413+ hypIndex};
389414
390415 if (lemma != lexicon_->specialLemma (" blank" ) and lemma != lexicon_->specialLemma (" silence" )) {
391416 const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence ();
392417 const Bliss::SyntacticToken* st = sts.front ();
393418
394- // Add the LM score and update the LM history
419+ // Add the LM score
395420 Lm::Score lmScore = languageModel_->score (wordEndExtension.lmHistory , st);
396421 wordEndExtension.score += lmScore;
397- wordEndExtension.lmScore = lmScore;
398- wordEndExtension.lmHistory = languageModel_->extendedHistory (wordEndExtension.lmHistory , st);
422+ wordEndExtension.lmScore = lmScore;
399423 }
400- wordEndExtensions_ .push_back (wordEndExtension);
424+ extensions_ .push_back (wordEndExtension);
401425 }
402426 }
403427 }
404428
405429 /*
406- * Prune set of word-end hypotheses by max beam size and possibly also by score.
430+ * Prune set of word-end extensions by max beam size and possibly also by score.
407431 */
408- scorePruning (wordEndExtensions_ , wordEndScoreThreshold_);
409- numWordEndHypsAfterScorePruning_ += wordEndExtensions_ .size ();
432+ scorePruning (extensions_ , wordEndScoreThreshold_);
433+ numWordEndHypsAfterScorePruning_ += extensions_ .size ();
410434 if (logStepwiseStatistics_) {
411- clog () << Core::XmlFull (" num-word-end-hyps-after-score-pruning" , wordEndExtensions_ .size ());
435+ clog () << Core::XmlFull (" num-word-end-hyps-after-score-pruning" , extensions_ .size ());
412436 }
413437
414- beamSizePruning (wordEndExtensions_, maxWordEndBeamSize_);
415- numWordEndHypsAfterBeamPruning_ += wordEndExtensions_.size ();
416- if (logStepwiseStatistics_) {
417- clog () << Core::XmlFull (" num-word-end-hyps-after-beam-pruning" , wordEndExtensions_.size ());
418- }
419-
420- /*
421- * Create new beam from surviving extensions.
422- */
423- newBeam_.clear ();
424- for (auto const & extension : withinWordExtensions_) {
425- auto const & baseHyp = beam_[extension.baseHypIndex ];
426-
427- auto newScoringContext = labelScorer_->extendedScoringContext (
428- {baseHyp.scoringContext ,
429- extension.nextToken ,
430- extension.transitionType });
438+ // Create new word-end label hypotheses from word-end extension candidates and update the LM history
439+ wordEndHypotheses_.clear ();
440+ for (auto & extension : extensions_) {
441+ const Bliss::Lemma* lemma = extension.pron ->lemma ();
442+ if (lemma != lexicon_->specialLemma (" blank" ) and lemma != lexicon_->specialLemma (" silence" )) {
443+ const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence ();
444+ const Bliss::SyntacticToken* st = sts.front ();
445+ extension.lmHistory = languageModel_->extendedHistory (extension.lmHistory , st);
446+ }
431447
432- newBeam_.push_back ({baseHyp, extension, newScoringContext});
448+ auto const & baseHyp = newBeam_[extension.baseHypIndex ];
449+ wordEndHypotheses_.push_back ({baseHyp, extension, baseHyp.scoringContext });
433450 }
434- for (auto const & extension : wordEndExtensions_) {
435- auto const & baseHyp = beam_[extension.baseHypIndex ];
436451
437- auto newScoringContext = labelScorer_->extendedScoringContext (
438- {baseHyp.scoringContext ,
439- extension.nextToken ,
440- extension.transitionType });
452+ recombination (wordEndHypotheses_);
453+ numWordEndHypsAfterRecombination_ += wordEndHypotheses_.size ();
454+ if (logStepwiseStatistics_) {
455+ clog () << Core::XmlFull (" num-word-end-hyps-after-recombination" , wordEndHypotheses_.size ());
456+ }
441457
442- newBeam_.push_back ({baseHyp, extension, newScoringContext});
458+ beamSizePruning (wordEndHypotheses_, maxWordEndBeamSize_);
459+ numWordEndHypsAfterBeamPruning_ += wordEndHypotheses_.size ();
460+ if (logStepwiseStatistics_) {
461+ clog () << Core::XmlFull (" num-word-end-hyps-after-beam-pruning" , wordEndHypotheses_.size ());
443462 }
444463
445- /*
446- * For all hypotheses at the same state and with the same scoring context and LM history
447- * keep only the best since they will all develop in the same way.
448- */
449- recombination (newBeam_);
450- numActiveHyps_ += newBeam_.size ();
464+ beam_.swap (newBeam_);
465+ beam_.insert (beam_.end (), wordEndHypotheses_.begin (), wordEndHypotheses_.end ());
466+
467+ numActiveHyps_ += beam_.size ();
451468
452469 /*
453- * Clean up label scorer caches.
470+ * Clean up label scorer caches and calculate number of active trees
454471 */
455- if (++currentSearchStep_ % cacheCleanupInterval_ == 0 ) {
456- Core::CollapsedVector<Nn::ScoringContextRef> activeContexts;
457- for (auto const & hyp : newBeam_) {
458- activeContexts.push_back (hyp.scoringContext );
472+ Core::CollapsedVector<Nn::ScoringContextRef> activeContexts;
473+ std::vector<Lm::History> seenHistories;
474+ for (auto const & hyp : beam_) {
475+ activeContexts.push_back (hyp.scoringContext );
476+ if (std::find (seenHistories.begin (), seenHistories.end (), hyp.lmHistory ) == seenHistories.end ()) {
477+ seenHistories.push_back (hyp.lmHistory );
459478 }
479+ }
480+ if (++currentSearchStep_ % cacheCleanupInterval_ == 0 ) {
460481 labelScorer_->cleanupCaches (activeContexts);
461482 }
483+ numActiveTrees_ += seenHistories.size ();
484+ if (logStepwiseStatistics_) {
485+ clog () << Core::XmlFull (" num-active-trees" , seenHistories.size ());
486+ }
462487
463488 /*
464- * Log statistics about the new beam after this step .
489+ * Log statistics about the new beam.
465490 */
466- beam_.swap (newBeam_);
467491
468492 if (debugChannel_.isOpen ()) {
469493 std::stringstream ss;
@@ -502,10 +526,13 @@ void TreeTimesyncBeamSearch::resetStatistics() {
502526 scoringTime_.reset ();
503527 contextExtensionTime_.reset ();
504528 numHypsAfterScorePruning_.clear ();
529+ numHypsAfterRecombination_.clear ();
505530 numHypsAfterBeamPruning_.clear ();
506531 numWordEndHypsAfterScorePruning_.clear ();
532+ numWordEndHypsAfterRecombination_.clear ();
507533 numWordEndHypsAfterBeamPruning_.clear ();
508534 numActiveHyps_.clear ();
535+ numActiveTrees_.clear ();
509536}
510537
511538void TreeTimesyncBeamSearch::logStatistics () const {
@@ -516,10 +543,13 @@ void TreeTimesyncBeamSearch::logStatistics() const {
516543 clog () << Core::XmlOpen (" context-extension-time" ) << contextExtensionTime_.elapsedMilliseconds () << Core::XmlClose (" context-extension-time" );
517544 clog () << Core::XmlClose (" timing-statistics" );
518545 numHypsAfterScorePruning_.write (clog ());
546+ numHypsAfterRecombination_.write (clog ());
519547 numHypsAfterBeamPruning_.write (clog ());
520548 numWordEndHypsAfterScorePruning_.write (clog ());
549+ numWordEndHypsAfterRecombination_.write (clog ());
521550 numWordEndHypsAfterBeamPruning_.write (clog ());
522551 numActiveHyps_.write (clog ());
552+ numActiveTrees_.write (clog ());
523553}
524554
525555Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType (Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const {
@@ -556,14 +586,14 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::
556586 }
557587}
558588
559- void TreeTimesyncBeamSearch::beamSizePruning (std::vector<TreeTimesyncBeamSearch::ExtensionCandidate >& extensions , size_t maxBeamSize) const {
560- if (extensions .size () <= maxBeamSize) {
589+ void TreeTimesyncBeamSearch::beamSizePruning (std::vector<LabelHypothesis >& hypotheses , size_t maxBeamSize) const {
590+ if (hypotheses .size () <= maxBeamSize) {
561591 return ;
562592 }
563593
564594 // Sort the hypotheses by associated score value such that the first `maxBeamSize` elements are the best
565- std::nth_element (extensions .begin (), extensions .begin () + maxBeamSize, extensions .end ());
566- extensions .resize (maxBeamSize); // Get rid of excessive elements
595+ std::nth_element (hypotheses .begin (), hypotheses .begin () + maxBeamSize, hypotheses .end ());
596+ hypotheses .resize (maxBeamSize); // Get rid of excessive elements
567597}
568598
569599void TreeTimesyncBeamSearch::scorePruning (std::vector<TreeTimesyncBeamSearch::ExtensionCandidate>& extensions, Score scoreThreshold) const {
0 commit comments