Skip to content

Commit db96b2f

Browse files
authored
Recombination before beam-pruning (#135)
1 parent 5114fe5 commit db96b2f

File tree

4 files changed

+126
-91
lines changed

4 files changed

+126
-91
lines changed

src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration
142142
scoringTime_(),
143143
contextExtensionTime_(),
144144
numHypsAfterScorePruning_("num-hyps-after-score-pruning"),
145+
numHypsAfterRecombination_("num-hyps-after-recombination"),
145146
numHypsAfterBeamPruning_("num-hyps-after-beam-pruning"),
146147
numActiveHyps_("num-active-hyps"),
147148
currentSearchStep_(0ul),
@@ -318,17 +319,8 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() {
318319
}
319320
}
320321

321-
beamSizePruning(extensions_);
322-
numHypsAfterBeamPruning_ += extensions_.size();
323-
if (logStepwiseStatistics_) {
324-
clog() << Core::XmlFull("num-hyps-after-beam-pruning", extensions_.size());
325-
}
326-
327-
/*
328-
* Create new beam from surviving extensions.
329-
*/
322+
// Create new beam from surviving extensions.
330323
newBeam_.clear();
331-
332324
for (auto const& extension : extensions_) {
333325
auto const& baseHyp = beam_[extension.baseHypIndex];
334326

@@ -340,11 +332,19 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() {
340332
newBeam_.push_back({baseHyp, extension, newScoringContext});
341333
}
342334

343-
/*
344-
* For all hypotheses with the same scoring context keep only the best since they will
345-
* all develop in the same way.
346-
*/
335+
// For all hypotheses with the same scoring context keep only the best since they will all develop in the same way.
347336
recombination(newBeam_);
337+
numHypsAfterRecombination_ += newBeam_.size();
338+
if (logStepwiseStatistics_) {
339+
clog() << Core::XmlFull("num-hyps-after-recombination", newBeam_.size());
340+
}
341+
342+
beamSizePruning(newBeam_);
343+
numHypsAfterBeamPruning_ += newBeam_.size();
344+
if (logStepwiseStatistics_) {
345+
clog() << Core::XmlFull("num-hyps-after-beam-pruning", newBeam_.size());
346+
}
347+
348348
numActiveHyps_ += newBeam_.size();
349349

350350
/*
@@ -400,6 +400,7 @@ void LexiconfreeTimesyncBeamSearch::resetStatistics() {
400400
scoringTime_.reset();
401401
contextExtensionTime_.reset();
402402
numHypsAfterScorePruning_.clear();
403+
numHypsAfterRecombination_.clear();
403404
numHypsAfterBeamPruning_.clear();
404405
numActiveHyps_.clear();
405406
}
@@ -412,6 +413,7 @@ void LexiconfreeTimesyncBeamSearch::logStatistics() const {
412413
clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << Core::XmlClose("context-extension-time");
413414
clog() << Core::XmlClose("timing-statistics");
414415
numHypsAfterScorePruning_.write(clog());
416+
numHypsAfterRecombination_.write(clog());
415417
numHypsAfterBeamPruning_.write(clog());
416418
numActiveHyps_.write(clog());
417419
}
@@ -450,14 +452,14 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy
450452
}
451453
}
452454

453-
void LexiconfreeTimesyncBeamSearch::beamSizePruning(std::vector<LexiconfreeTimesyncBeamSearch::ExtensionCandidate>& extensions) const {
454-
if (extensions.size() <= maxBeamSize_) {
455+
void LexiconfreeTimesyncBeamSearch::beamSizePruning(std::vector<LabelHypothesis>& hypotheses) const {
456+
if (hypotheses.size() <= maxBeamSize_) {
455457
return;
456458
}
457459

458460
// Reorder the hypotheses by associated score value such that the first `maxBeamSize_` elements are the best
459-
std::nth_element(extensions.begin(), extensions.begin() + maxBeamSize_, extensions.end());
460-
extensions.resize(maxBeamSize_); // Get rid of excessive elements
461+
std::nth_element(hypotheses.begin(), hypotheses.begin() + maxBeamSize_, hypotheses.end());
462+
hypotheses.resize(maxBeamSize_); // Get rid of excessive elements
461463
}
462464

463465
void LexiconfreeTimesyncBeamSearch::scorePruning(std::vector<LexiconfreeTimesyncBeamSearch::ExtensionCandidate>& extensions) const {

src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ private:
134134
Core::StopWatch contextExtensionTime_;
135135

136136
Core::Statistics<u32> numHypsAfterScorePruning_;
137+
Core::Statistics<u32> numHypsAfterRecombination_;
137138
Core::Statistics<u32> numHypsAfterBeamPruning_;
138139
Core::Statistics<u32> numActiveHyps_;
139140

@@ -155,7 +156,7 @@ private:
155156
/*
156157
* Helper function for pruning to maxBeamSize_
157158
*/
158-
void beamSizePruning(std::vector<LexiconfreeTimesyncBeamSearch::ExtensionCandidate>& extensions) const;
159+
void beamSizePruning(std::vector<LabelHypothesis>& hypotheses) const;
159160

160161
/*
161162
* Helper function for pruning to scoreThreshold_

src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc

Lines changed: 99 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,9 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config
145145
labelScorer_(),
146146
debugChannel_(config, "debug"),
147147
extensions_(),
148-
withinWordExtensions_(),
149-
wordEndExtensions_(),
150148
beam_(),
151149
newBeam_(),
150+
wordEndHypotheses_(),
152151
requests_(),
153152
recombinedHypotheses_(),
154153
currentSearchStep_(0ul),
@@ -158,10 +157,13 @@ TreeTimesyncBeamSearch::TreeTimesyncBeamSearch(Core::Configuration const& config
158157
scoringTime_(),
159158
contextExtensionTime_(),
160159
numHypsAfterScorePruning_("num-hyps-after-score-pruning"),
160+
numHypsAfterRecombination_("num-hyps-after-recombination"),
161161
numHypsAfterBeamPruning_("num-hyps-after-beam-pruning"),
162162
numWordEndHypsAfterScorePruning_("num-word-end-hyps-after-score-pruning"),
163+
numWordEndHypsAfterRecombination_("num-word-end-hyps-after-recombination"),
163164
numWordEndHypsAfterBeamPruning_("num-word-end-hyps-after-beam-pruning"),
164-
numActiveHyps_("num-active-hyps") {
165+
numActiveHyps_("num-active-hyps"),
166+
numActiveTrees_("num-active-trees") {
165167
if (scoreThreshold_ == Core::Type<Score>::max and wordEndScoreThreshold_ != Core::Type<Score>::max) {
166168
error() << "Word-end score-threshold which is relative to the score-threshold is set, but score-threshold is not set";
167169
}
@@ -359,111 +361,133 @@ bool TreeTimesyncBeamSearch::decodeStep() {
359361
clog() << Core::XmlFull("num-hyps-after-score-pruning", extensions_.size());
360362
}
361363

362-
beamSizePruning(extensions_, maxBeamSize_);
363-
numHypsAfterBeamPruning_ += extensions_.size();
364+
// Create new label hypotheses from extension candidates
365+
newBeam_.clear();
366+
for (auto const& extension : extensions_) {
367+
auto const& baseHyp = beam_[extension.baseHypIndex];
368+
369+
auto newScoringContext = labelScorer_->extendedScoringContext(
370+
{baseHyp.scoringContext,
371+
extension.nextToken,
372+
extension.transitionType});
373+
374+
newBeam_.push_back({baseHyp, extension, newScoringContext});
375+
}
376+
377+
// For all hypotheses at the same state and with the same scoring context and LM history
378+
// keep only the best since they will all develop in the same way
379+
recombination(newBeam_);
380+
numHypsAfterRecombination_ += newBeam_.size();
381+
if (logStepwiseStatistics_) {
382+
clog() << Core::XmlFull("num-hyps-after-recombination", newBeam_.size());
383+
}
384+
385+
beamSizePruning(newBeam_, maxBeamSize_);
386+
numHypsAfterBeamPruning_ += newBeam_.size();
364387
if (logStepwiseStatistics_) {
365-
clog() << Core::XmlFull("num-hyps-after-beam-pruning", extensions_.size());
388+
clog() << Core::XmlFull("num-hyps-after-beam-pruning", newBeam_.size());
366389
}
367390

368391
/*
369-
* Expand extensions to word-end hypotheses and incorporate the language model
392+
* Expand hypotheses to word-end hypotheses and incorporate the language model
370393
*/
371-
withinWordExtensions_.clear();
372-
wordEndExtensions_.clear();
373-
for (const auto& extension : extensions_) {
374-
// If there is at least one state successor, keep it as within-word hypothesis
375-
if (not stateSuccessorLookup_[extension.state].empty()) {
376-
withinWordExtensions_.push_back(extension);
377-
}
378-
std::vector<PersistentStateTree::Exit> exitList = exitLookup_[extension.state];
394+
extensions_.clear();
395+
for (size_t hypIndex = 0ul; hypIndex < newBeam_.size(); ++hypIndex) {
396+
auto& hyp = newBeam_[hypIndex];
397+
398+
std::vector<PersistentStateTree::Exit> exitList = exitLookup_[hyp.currentState];
379399
if (not exitList.empty()) {
380400
// Create one word-end hypothesis for each exit
381401
for (const auto& exit : exitList) {
382-
ExtensionCandidate wordEndExtension(extension);
383402
const Bliss::LemmaPronunciation* lemmaPron = lexicon_->lemmaPronunciation(exit.pronunciation);
384403
const Bliss::Lemma* lemma = lemmaPron->lemma();
385404

386-
// Start from the root node (the exit's transit state) in the next step
387-
wordEndExtension.state = exit.transitState;
388-
wordEndExtension.pron = lemmaPron;
405+
ExtensionCandidate wordEndExtension{hyp.currentToken,
406+
lemmaPron,
407+
exit.transitState, // Start from the root node (the exit's transit state) in the next step
408+
hyp.lmHistory,
409+
hyp.score,
410+
0.0,
411+
static_cast<TimeframeIndex>(currentSearchStep_),
412+
Nn::LabelScorer::TransitionType::INITIAL_BLANK, // The transition type is irrelevant, so just use this as dummy
413+
hypIndex};
389414

390415
if (lemma != lexicon_->specialLemma("blank") and lemma != lexicon_->specialLemma("silence")) {
391416
const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence();
392417
const Bliss::SyntacticToken* st = sts.front();
393418

394-
// Add the LM score and update the LM history
419+
// Add the LM score
395420
Lm::Score lmScore = languageModel_->score(wordEndExtension.lmHistory, st);
396421
wordEndExtension.score += lmScore;
397-
wordEndExtension.lmScore = lmScore;
398-
wordEndExtension.lmHistory = languageModel_->extendedHistory(wordEndExtension.lmHistory, st);
422+
wordEndExtension.lmScore = lmScore;
399423
}
400-
wordEndExtensions_.push_back(wordEndExtension);
424+
extensions_.push_back(wordEndExtension);
401425
}
402426
}
403427
}
404428

405429
/*
406-
* Prune set of word-end hypotheses by max beam size and possibly also by score.
430+
* Prune set of word-end extensions by score. Pruning by max beam size is applied to the word-end hypotheses after recombination.
407431
*/
408-
scorePruning(wordEndExtensions_, wordEndScoreThreshold_);
409-
numWordEndHypsAfterScorePruning_ += wordEndExtensions_.size();
432+
scorePruning(extensions_, wordEndScoreThreshold_);
433+
numWordEndHypsAfterScorePruning_ += extensions_.size();
410434
if (logStepwiseStatistics_) {
411-
clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", wordEndExtensions_.size());
435+
clog() << Core::XmlFull("num-word-end-hyps-after-score-pruning", extensions_.size());
412436
}
413437

414-
beamSizePruning(wordEndExtensions_, maxWordEndBeamSize_);
415-
numWordEndHypsAfterBeamPruning_ += wordEndExtensions_.size();
416-
if (logStepwiseStatistics_) {
417-
clog() << Core::XmlFull("num-word-end-hyps-after-beam-pruning", wordEndExtensions_.size());
418-
}
419-
420-
/*
421-
* Create new beam from surviving extensions.
422-
*/
423-
newBeam_.clear();
424-
for (auto const& extension : withinWordExtensions_) {
425-
auto const& baseHyp = beam_[extension.baseHypIndex];
426-
427-
auto newScoringContext = labelScorer_->extendedScoringContext(
428-
{baseHyp.scoringContext,
429-
extension.nextToken,
430-
extension.transitionType});
438+
// Create new word-end label hypotheses from word-end extension candidates and update the LM history
439+
wordEndHypotheses_.clear();
440+
for (auto& extension : extensions_) {
441+
const Bliss::Lemma* lemma = extension.pron->lemma();
442+
if (lemma != lexicon_->specialLemma("blank") and lemma != lexicon_->specialLemma("silence")) {
443+
const Bliss::SyntacticTokenSequence sts = lemma->syntacticTokenSequence();
444+
const Bliss::SyntacticToken* st = sts.front();
445+
extension.lmHistory = languageModel_->extendedHistory(extension.lmHistory, st);
446+
}
431447

432-
newBeam_.push_back({baseHyp, extension, newScoringContext});
448+
auto const& baseHyp = newBeam_[extension.baseHypIndex];
449+
wordEndHypotheses_.push_back({baseHyp, extension, baseHyp.scoringContext});
433450
}
434-
for (auto const& extension : wordEndExtensions_) {
435-
auto const& baseHyp = beam_[extension.baseHypIndex];
436451

437-
auto newScoringContext = labelScorer_->extendedScoringContext(
438-
{baseHyp.scoringContext,
439-
extension.nextToken,
440-
extension.transitionType});
452+
recombination(wordEndHypotheses_);
453+
numWordEndHypsAfterRecombination_ += wordEndHypotheses_.size();
454+
if (logStepwiseStatistics_) {
455+
clog() << Core::XmlFull("num-word-end-hyps-after-recombination", wordEndHypotheses_.size());
456+
}
441457

442-
newBeam_.push_back({baseHyp, extension, newScoringContext});
458+
beamSizePruning(wordEndHypotheses_, maxWordEndBeamSize_);
459+
numWordEndHypsAfterBeamPruning_ += wordEndHypotheses_.size();
460+
if (logStepwiseStatistics_) {
461+
clog() << Core::XmlFull("num-word-end-hyps-after-beam-pruning", wordEndHypotheses_.size());
443462
}
444463

445-
/*
446-
* For all hypotheses at the same state and with the same scoring context and LM history
447-
* keep only the best since they will all develop in the same way.
448-
*/
449-
recombination(newBeam_);
450-
numActiveHyps_ += newBeam_.size();
464+
beam_.swap(newBeam_);
465+
beam_.insert(beam_.end(), wordEndHypotheses_.begin(), wordEndHypotheses_.end());
466+
467+
numActiveHyps_ += beam_.size();
451468

452469
/*
453-
* Clean up label scorer caches.
470+
* Clean up label scorer caches and calculate number of active trees
454471
*/
455-
if (++currentSearchStep_ % cacheCleanupInterval_ == 0) {
456-
Core::CollapsedVector<Nn::ScoringContextRef> activeContexts;
457-
for (auto const& hyp : newBeam_) {
458-
activeContexts.push_back(hyp.scoringContext);
472+
Core::CollapsedVector<Nn::ScoringContextRef> activeContexts;
473+
std::vector<Lm::History> seenHistories;
474+
for (auto const& hyp : beam_) {
475+
activeContexts.push_back(hyp.scoringContext);
476+
if (std::find(seenHistories.begin(), seenHistories.end(), hyp.lmHistory) == seenHistories.end()) {
477+
seenHistories.push_back(hyp.lmHistory);
459478
}
479+
}
480+
if (++currentSearchStep_ % cacheCleanupInterval_ == 0) {
460481
labelScorer_->cleanupCaches(activeContexts);
461482
}
483+
numActiveTrees_ += seenHistories.size();
484+
if (logStepwiseStatistics_) {
485+
clog() << Core::XmlFull("num-active-trees", seenHistories.size());
486+
}
462487

463488
/*
464-
* Log statistics about the new beam after this step.
489+
* Log statistics about the new beam.
465490
*/
466-
beam_.swap(newBeam_);
467491

468492
if (debugChannel_.isOpen()) {
469493
std::stringstream ss;
@@ -502,10 +526,13 @@ void TreeTimesyncBeamSearch::resetStatistics() {
502526
scoringTime_.reset();
503527
contextExtensionTime_.reset();
504528
numHypsAfterScorePruning_.clear();
529+
numHypsAfterRecombination_.clear();
505530
numHypsAfterBeamPruning_.clear();
506531
numWordEndHypsAfterScorePruning_.clear();
532+
numWordEndHypsAfterRecombination_.clear();
507533
numWordEndHypsAfterBeamPruning_.clear();
508534
numActiveHyps_.clear();
535+
numActiveTrees_.clear();
509536
}
510537

511538
void TreeTimesyncBeamSearch::logStatistics() const {
@@ -516,10 +543,13 @@ void TreeTimesyncBeamSearch::logStatistics() const {
516543
clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << Core::XmlClose("context-extension-time");
517544
clog() << Core::XmlClose("timing-statistics");
518545
numHypsAfterScorePruning_.write(clog());
546+
numHypsAfterRecombination_.write(clog());
519547
numHypsAfterBeamPruning_.write(clog());
520548
numWordEndHypsAfterScorePruning_.write(clog());
549+
numWordEndHypsAfterRecombination_.write(clog());
521550
numWordEndHypsAfterBeamPruning_.write(clog());
522551
numActiveHyps_.write(clog());
552+
numActiveTrees_.write(clog());
523553
}
524554

525555
Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const {
@@ -556,14 +586,14 @@ Nn::LabelScorer::TransitionType TreeTimesyncBeamSearch::inferTransitionType(Nn::
556586
}
557587
}
558588

559-
void TreeTimesyncBeamSearch::beamSizePruning(std::vector<TreeTimesyncBeamSearch::ExtensionCandidate>& extensions, size_t maxBeamSize) const {
560-
if (extensions.size() <= maxBeamSize) {
589+
void TreeTimesyncBeamSearch::beamSizePruning(std::vector<LabelHypothesis>& hypotheses, size_t maxBeamSize) const {
590+
if (hypotheses.size() <= maxBeamSize) {
561591
return;
562592
}
563593

564594
// Partially reorder the hypotheses by associated score value such that the first `maxBeamSize` elements are the best (in unspecified order)
565-
std::nth_element(extensions.begin(), extensions.begin() + maxBeamSize, extensions.end());
566-
extensions.resize(maxBeamSize); // Get rid of excessive elements
595+
std::nth_element(hypotheses.begin(), hypotheses.begin() + maxBeamSize, hypotheses.end());
596+
hypotheses.resize(maxBeamSize); // Get rid of excessive elements
567597
}
568598

569599
void TreeTimesyncBeamSearch::scorePruning(std::vector<TreeTimesyncBeamSearch::ExtensionCandidate>& extensions, Score scoreThreshold) const {

0 commit comments

Comments
 (0)