Draft
Commits (23)
7502b82
Add TransitionLabelScorer
SimBe195 Jul 24, 2025
7430001
Rewrite docstring
SimBe195 Jul 24, 2025
2a6272e
Clean up includes
SimBe195 Jul 24, 2025
7e325e1
Rewrite docstring again
SimBe195 Jul 24, 2025
a276136
Merge branch 'master' into tdp_label_scorer
SimBe195 Sep 24, 2025
d2d78fe
Refactor params to string list with compile time check
SimBe195 Sep 24, 2025
303fa46
Remove transitionTypeToIndex function and revert associated changes
SimBe195 Sep 24, 2025
ddd75c7
Revert unnecessary static_cast
SimBe195 Sep 24, 2025
b856c1e
Change std=c++17 to c++20
SimBe195 Sep 30, 2025
70699c0
Merge remote-tracking branch 'origin/version-bump' into tdp_label_scorer
SimBe195 Sep 30, 2025
5b89d0f
Move transition type string array to LabelScorer.hh
SimBe195 Sep 30, 2025
b9d919b
Move transitionTypeArray to protected space
SimBe195 Oct 1, 2025
54bee17
Add sentence-end transition to enum
SimBe195 Oct 6, 2025
3dc887b
Sentence-end handling for lexiconfree-search
SimBe195 Oct 6, 2025
b1ba86a
Add finalStates collection to PersistentStateTree
SimBe195 Oct 6, 2025
667f558
Sentence-end handling for tree-search
SimBe195 Oct 8, 2025
1795685
Merge branch 'master' into tdp_label_scorer
SimBe195 Oct 8, 2025
98b824f
Merge branch 'tdp_label_scorer' into sentence_end_handling
SimBe195 Oct 8, 2025
dfdcfe7
Allow no pronunciations of sentence-end
SimBe195 Oct 8, 2025
aa099a4
Add sentence-end-index as member and to inferTransitionType in TreeTi…
SimBe195 Oct 8, 2025
a04c2f9
Change log to warning when sentence-end is not included in tree
SimBe195 Oct 8, 2025
1ee5547
Merge branch 'master' into sentence_end_handling
SimBe195 Oct 8, 2025
d4f6202
Suggestions from code review
SimBe195 Oct 9, 2025
1 change: 1 addition & 0 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
@@ -119,6 +119,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore
case TransitionType::BLANK_TO_LABEL:
case TransitionType::LABEL_TO_LABEL:
case TransitionType::INITIAL_LABEL:
case TransitionType::SENTENCE_END:
pushToken = true;
timeIncrement = not verticalLabelTransition_;
break;
2 changes: 2 additions & 0 deletions src/Nn/LabelScorer/LabelScorer.hh
@@ -84,6 +84,7 @@ public:
BLANK_LOOP,
INITIAL_LABEL,
INITIAL_BLANK,
SENTENCE_END,
numTypes, // must remain at the end
};

@@ -153,6 +154,7 @@ protected:
{"blank-loop", BLANK_LOOP},
{"initial-label", INITIAL_LABEL},
{"initial-blank", INITIAL_BLANK},
{"sentence-end", SENTENCE_END},
});
static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values");
};
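The `transitionTypeArray_` plus the `static_assert` above tie each transition name to its enum value and break the build if a new `TransitionType` (such as `SENTENCE_END`) is added without registering a string for it. The following is a minimal standalone sketch of that pattern; the names and the use of `std::to_array` (C++20, which this PR switches to) are illustrative assumptions, not RASR's actual helpers:

#include <array>
#include <optional>
#include <string_view>
#include <utility>

enum TransitionType { LABEL_TO_LABEL, LABEL_TO_BLANK, SENTENCE_END, numTypes };

// One entry per enum value; the static_assert fails if the enum grows without this table.
constexpr auto transitionNames = std::to_array<std::pair<std::string_view, TransitionType>>({
        {"label-to-label", LABEL_TO_LABEL},
        {"label-to-blank", LABEL_TO_BLANK},
        {"sentence-end", SENTENCE_END},
});
static_assert(transitionNames.size() == TransitionType::numTypes,
              "transitionNames must cover every TransitionType value");

// Example lookup, e.g. when parsing a transition name from a config string.
constexpr std::optional<TransitionType> parseTransitionType(std::string_view name) {
    for (auto const& [n, t] : transitionNames) {
        if (n == name) {
            return t;
        }
    }
    return std::nullopt;
}

static_assert(parseTransitionType("sentence-end") == SENTENCE_END);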
1 change: 1 addition & 0 deletions src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc
@@ -218,6 +218,7 @@ Core::Ref<const ScoringContext> StatefulOnnxLabelScorer::extendedScoringContext(
case LabelScorer::TransitionType::BLANK_TO_LABEL:
case LabelScorer::TransitionType::LABEL_TO_LABEL:
case LabelScorer::TransitionType::INITIAL_LABEL:
case LabelScorer::TransitionType::SENTENCE_END:
updateState = true;
break;
default:
LexiconfreeTimesyncBeamSearch.cc
@@ -37,7 +37,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis()
: scoringContext(),
currentToken(Nn::invalidLabelIndex),
score(0.0),
trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))),
reachedSentenceEnd(false) {}

LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base,
@@ -46,19 +47,22 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
: scoringContext(newScoringContext),
currentToken(extension.nextToken),
score(extension.score),
trace(),
reachedSentenceEnd(base.reachedSentenceEnd) {
switch (extension.transitionType) {
case Nn::LabelScorer::INITIAL_BLANK:
case Nn::LabelScorer::INITIAL_LABEL:
case Nn::LabelScorer::LABEL_TO_LABEL:
case Nn::LabelScorer::LABEL_TO_BLANK:
case Nn::LabelScorer::BLANK_TO_LABEL:
case Nn::LabelScorer::SENTENCE_END:
trace = Core::ref(new LatticeTrace(
base.trace,
extension.pron,
extension.timeframe + 1,
{extension.score, 0},
{}));
reachedSentenceEnd = true;
break;
case Nn::LabelScorer::LABEL_LOOP:
case Nn::LabelScorer::BLANK_LOOP:
@@ -68,6 +72,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis(
trace->score.acoustic = extension.score;
trace->time = extension.timeframe + 1;
break;
default:
break;
}
}

@@ -106,6 +112,21 @@ const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramBlankLabelIndex(
"Index of the blank label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.",
Nn::invalidLabelIndex);

const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramSentenceEndLabelIndex(
"sentence-end-label-index",
"Index of the sentence end label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='sentence-end'` or `special='sentence-boundary'`. If not set, the search will not use sentence end.",
Nn::invalidLabelIndex);

const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramAllowBlankAfterSentenceEnd(
"allow-blank-after-sentence-end",
"blanks can still be produced after the sentence-end has been reached",
true);

const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramSentenceEndFallBack(
"sentence-end-fall-back",
"Allow for fallback solution if no active word-end hypothesis exists at the end of a segment.",
true);

const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramCollapseRepeatedLabels(
"collapse-repeated-labels",
"Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.",
@@ -126,7 +147,11 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration
SearchAlgorithmV2(config),
maxBeamSize_(paramMaxBeamSize(config)),
scoreThreshold_(paramScoreThreshold(config)),
sentenceEndFallback_(paramSentenceEndFallBack(config)),
blankLabelIndex_(paramBlankLabelIndex(config)),
allowBlankAfterSentenceEnd_(paramAllowBlankAfterSentenceEnd(config)),
sentenceEndLemma_(),
sentenceEndLabelIndex_(paramSentenceEndLabelIndex(config)),
collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)),
logStepwiseStatistics_(paramLogStepwiseStatistics(config)),
cacheCleanupInterval_(paramCacheCleanupInterval(config)),
@@ -154,6 +179,12 @@ LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration
if (useBlank_) {
log() << "Use blank label with index " << blankLabelIndex_;
}

useSentenceEnd_ = sentenceEndLabelIndex_ != Nn::invalidLabelIndex;
if (useSentenceEnd_) {
log() << "Use sentence end label with index " << sentenceEndLabelIndex_;
}

useScorePruning_ = scoreThreshold_ != Core::Type<Score>::max;
}

@@ -180,6 +211,21 @@ bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination
}
}

sentenceEndLemma_ = lexicon_->specialLemma("sentence-end");
if (!sentenceEndLemma_) {
sentenceEndLemma_ = lexicon_->specialLemma("sentence-boundary");
}
if (sentenceEndLemma_) {
if (sentenceEndLabelIndex_ == Nn::invalidLabelIndex) {
sentenceEndLabelIndex_ = sentenceEndLemma_->id();
useSentenceEnd_ = true;
log() << "Use sentence-end index " << sentenceEndLabelIndex_ << " inferred from lexicon";
}
else if (sentenceEndLabelIndex_ != static_cast<Nn::LabelIndex>(sentenceEndLemma_->id())) {
warning() << "SentenceEnd lemma exists in lexicon with id " << sentenceEndLemma_->id() << " but is overwritten by config parameter with value " << sentenceEndLabelIndex_;
}
}

reset();
return true;
}
@@ -214,6 +260,7 @@ void LexiconfreeTimesyncBeamSearch::finishSegment() {
labelScorer_->signalNoMoreFeatures();
featureProcessingTime_.stop();
decodeManySteps();
finalizeHypotheses();
logStatistics();
finishedSegment_ = true;
}
@@ -271,6 +318,14 @@ bool LexiconfreeTimesyncBeamSearch::decodeStep() {
const Bliss::Lemma* lemma(*lemmaIt);
Nn::LabelIndex tokenIdx = lemma->id();

// After the first sentence-end token, only allow looping that sentence-end or emitting further blanks
if (hyp.reachedSentenceEnd and
not(
(collapseRepeatedLabels_ and hyp.currentToken == sentenceEndLabelIndex_ and tokenIdx == sentenceEndLabelIndex_) // sentence-end-loop
or (allowBlankAfterSentenceEnd_ and tokenIdx == blankLabelIndex_))) { // blank
continue;
}

auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx);

extensions_.push_back(
@@ -419,13 +474,17 @@ void LexiconfreeTimesyncBeamSearch::logStatistics() const {
}

Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const {
bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_);
bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_);
bool nextIsSentenceEnd = (useSentenceEnd_ and nextLabel == sentenceEndLabelIndex_);

if (prevLabel == Nn::invalidLabelIndex) {
if (nextIsBlank) {
return Nn::LabelScorer::TransitionType::INITIAL_BLANK;
}
else if (nextIsSentenceEnd) {
return Nn::LabelScorer::TransitionType::SENTENCE_END;
}
else {
return Nn::LabelScorer::TransitionType::INITIAL_LABEL;
}
@@ -435,6 +494,9 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy
if (nextIsBlank) {
return Nn::LabelScorer::TransitionType::BLANK_LOOP;
}
else if (nextIsSentenceEnd) {
return Nn::LabelScorer::TransitionType::SENTENCE_END;
}
else {
return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL;
}
@@ -446,6 +508,9 @@ Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionTy
else if (collapseRepeatedLabels_ and prevLabel == nextLabel) {
return Nn::LabelScorer::TransitionType::LABEL_LOOP;
}
else if (nextIsSentenceEnd) {
return Nn::LabelScorer::TransitionType::SENTENCE_END;
}
else {
return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL;
}
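To make the effect of the new `SENTENCE_END` branches concrete, here is a small self-contained sketch (not part of the diff) that mirrors the decision order visible above. The blank and sentence-end indices are made-up example values, and the label-to-blank branch, which is elided in the diff, is assumed:

#include <cassert>
#include <cstdint>
#include <limits>

enum class Transition { InitialBlank, InitialLabel, BlankLoop, BlankToLabel,
                        LabelLoop, LabelToBlank, LabelToLabel, SentenceEnd };

constexpr uint32_t invalid    = std::numeric_limits<uint32_t>::max();  // stands in for Nn::invalidLabelIndex
constexpr uint32_t blankIdx   = 0;  // assumed blank index
constexpr uint32_t sentEndIdx = 1;  // assumed sentence-end index

Transition infer(uint32_t prev, uint32_t next, bool collapse = true) {
    bool prevBlank   = prev == blankIdx;
    bool nextBlank   = next == blankIdx;
    bool nextSentEnd = next == sentEndIdx;
    if (prev == invalid) {  // first emission of the segment
        return nextBlank ? Transition::InitialBlank
             : nextSentEnd ? Transition::SentenceEnd
                           : Transition::InitialLabel;
    }
    if (prevBlank) {
        return nextBlank ? Transition::BlankLoop
             : nextSentEnd ? Transition::SentenceEnd
                           : Transition::BlankToLabel;
    }
    if (nextBlank) {  // assumed: this branch is not shown in the diff
        return Transition::LabelToBlank;
    }
    if (collapse && prev == next) {  // repeated labels collapse, even for the sentence-end token
        return Transition::LabelLoop;
    }
    return nextSentEnd ? Transition::SentenceEnd : Transition::LabelToLabel;
}

int main() {
    assert(infer(invalid, sentEndIdx) == Transition::SentenceEnd);   // sentence-end as the very first token
    assert(infer(blankIdx, sentEndIdx) == Transition::SentenceEnd);  // sentence-end after blank
    assert(infer(2, sentEndIdx) == Transition::SentenceEnd);         // sentence-end after a regular label
    assert(infer(sentEndIdx, sentEndIdx) == Transition::LabelLoop);  // looping the sentence-end token itself
    assert(infer(2, 3) == Transition::LabelToLabel);
    return 0;
}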
@@ -513,4 +578,36 @@ void LexiconfreeTimesyncBeamSearch::recombination(std::vector<LexiconfreeTimesyn
hypotheses.swap(recombinedHypotheses_);
}

void LexiconfreeTimesyncBeamSearch::finalizeHypotheses() {
if (not useSentenceEnd_) {
return;
}

newBeam_.clear();
for (auto const& hyp : beam_) {
if (hyp.reachedSentenceEnd) {
newBeam_.push_back(hyp);
}
}

if (newBeam_.empty()) { // There was no valid final hypothesis in the beam
warning("No hypothesis has produced sentence-end by the end of the segment.");
if (sentenceEndFallback_) {
log() << "Use sentence-end fallback";
// Keep `beam_` as it is
}
else {
newBeam_.push_back(LabelHypothesis());
newBeam_.front().trace->time = beam_.front().trace->time; // Retrieve the timeframe from any hyp in the old beam
newBeam_.front().trace->pronunciation = nullptr;
newBeam_.front().trace->predecessor = Core::ref(new LatticeTrace(0, {0, 0}, {}));
newBeam_.front().reachedSentenceEnd = true;
beam_.swap(newBeam_);
}
}
else {
newBeam_.swap(beam_);
}
}

} // namespace Search
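The finalize step above is a "filter-or-fallback" beam finalization: keep only hypotheses that reached sentence-end; if none did, either return the unfiltered beam (fallback enabled) or a single default-constructed hypothesis (fallback disabled), which yields an empty recognition result. A generic sketch of that design choice with a hypothetical helper, not RASR code:

#include <vector>

template <typename Hyp, typename Pred>
std::vector<Hyp> finalizeBeam(std::vector<Hyp> beam, Pred isFinal, bool fallback) {
    std::vector<Hyp> finals;
    for (auto const& hyp : beam) {
        if (isFinal(hyp)) {
            finals.push_back(hyp);
        }
    }
    if (!finals.empty()) {
        return finals;  // normal case: at least one hypothesis reached sentence-end
    }
    if (fallback) {
        return beam;    // fallback: keep the non-final hypotheses as the result
    }
    return {Hyp{}};     // no fallback: report a single empty hypothesis
}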
LexiconfreeTimesyncBeamSearch.hh
@@ -43,6 +43,9 @@ public:
static const Core::ParameterInt paramMaxBeamSize;
static const Core::ParameterFloat paramScoreThreshold;
static const Core::ParameterInt paramBlankLabelIndex;
static const Core::ParameterInt paramSentenceEndLabelIndex;
static const Core::ParameterBool paramAllowBlankAfterSentenceEnd;
static const Core::ParameterBool paramSentenceEndFallBack;
static const Core::ParameterBool paramCollapseRepeatedLabels;
static const Core::ParameterBool paramCacheCleanupInterval;
static const Core::ParameterBool paramLogStepwiseStatistics;
@@ -83,10 +86,11 @@ protected:
* Struct containing all information about a single hypothesis in the beam
*/
struct LabelHypothesis {
Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis
Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type)
Score score; // Full score of hypothesis
Core::Ref<LatticeTrace> trace; // Associated trace for traceback or lattice building off of hypothesis
bool reachedSentenceEnd; // Flag whether hypothesis trace contains a sentence end emission

LabelHypothesis();
LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext);
@@ -107,8 +111,15 @@ private:
bool useScorePruning_;
Score scoreThreshold_;

bool sentenceEndFallback_;

bool useBlank_;
Nn::LabelIndex blankLabelIndex_;
bool allowBlankAfterSentenceEnd_;

bool useSentenceEnd_;
Bliss::Lemma const* sentenceEndLemma_;
Nn::LabelIndex sentenceEndLabelIndex_;

bool collapseRepeatedLabels_;

@@ -167,6 +178,12 @@ private:
* Helper function for recombination of hypotheses with the same scoring context
*/
void recombination(std::vector<LabelHypothesis>& hypotheses);

/*
* Prune away all hypotheses that have not reached sentence-end.
* If no hypothesis would survive this, keep the beam intact when sentence-end fallback is enabled; otherwise replace it with a single empty hypothesis.
*/
void finalizeHypotheses();
};

} // namespace Search
11 changes: 7 additions & 4 deletions src/Search/PersistentStateTree.cc
@@ -33,7 +33,7 @@ static const Core::ParameterString paramCacheArchive(
"cache archive in which the persistent state-network should be cached",
"global-cache");

static u32 formatVersion = 14;

namespace Search {
struct ConvertTree {
@@ -298,7 +298,7 @@ void PersistentStateTree::write(Core::MappedArchiveWriter out) {

out << coarticulatedRootStates << unpushedCoarticulatedRootStates;
out << rootTransitDescriptions << pushedWordEndNodes << uncoarticulatedWordEndStates;
out << rootState << ciRootState << otherRootStates << finalStates;
}

bool PersistentStateTree::read(Core::MappedArchiveReader in) {
@@ -307,8 +307,8 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) {

/// @todo Eventually do memory-mapping

if (v < 13) {
Core::Application::us()->log() << "Wrong compressed network format, need version >= 13 got " << v;
return false;
}

Expand All @@ -333,6 +333,9 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) {
in >> pushedWordEndNodes >> uncoarticulatedWordEndStates;

in >> rootState >> ciRootState >> otherRootStates;
if (v >= 14) {
in >> finalStates;
}

return in.good();
}
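The `formatVersion` bump and the guarded read above follow the usual versioned-cache pattern: new members are appended to the archive, the reader accepts any version it can still interpret, and version-gated fields simply keep their defaults for older caches. A standalone sketch of that pattern with a mock archive instead of `Core::MappedArchiveReader` (assumed names, not RASR's API):

#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

struct MockArchive {  // stands in for Core::MappedArchiveReader
    std::vector<uint32_t> data;
    std::size_t pos = 0;
    bool read(uint32_t& value) {
        if (pos >= data.size()) {
            return false;
        }
        value = data[pos++];
        return true;
    }
};

struct Tree {
    uint32_t rootState = 0;
    std::set<uint32_t> finalStates;  // added with format version 14

    bool read(MockArchive& in, uint32_t version) {
        if (version < 13) {
            return false;  // too old: reject and let the caller rebuild the cache
        }
        if (!in.read(rootState)) {
            return false;
        }
        if (version >= 14) {  // caches written before version 14 simply leave finalStates empty
            uint32_t count = 0, state = 0;
            if (!in.read(count)) {
                return false;
            }
            for (uint32_t i = 0; i < count; ++i) {
                if (!in.read(state)) {
                    return false;
                }
                finalStates.insert(state);
            }
        }
        return true;
    }
};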
3 changes: 3 additions & 0 deletions src/Search/PersistentStateTree.hh
@@ -108,6 +108,9 @@ public:
// Other root nodes (currently used for the wordBoundaryRoot in CtcTreeBuilder)
std::set<StateId> otherRootStates;

// Valid nodes that the search can end in
std::set<StateId> finalStates;

// The word-end exits
std::vector<Exit> exits;
