-
Notifications
You must be signed in to change notification settings - Fork 16
Update Nn::LabelScorer: allow enabling of specific transition types #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7502b82
7430001
2a6272e
7e325e1
a276136
d2d78fe
303fa46
ddd75c7
b856c1e
70699c0
5b89d0f
8b27b19
1e877c7
3dcadee
b9d919b
23df463
a437c04
8337fea
0cfdf3d
85522b5
4098ad4
0d34085
f365db2
6e755d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,26 +23,29 @@ namespace Nn { | |
| * ============================= | ||
| */ | ||
|
|
||
| const Core::ParameterStringVector LabelScorer::paramIgnoredTransitionTypes( | ||
| "ignored-transition-types", | ||
| "Transition types that should be ignored by the label scorer (i.e. get assigned score 0 and do not affect the ScoringContext)", | ||
| const Core::Choice LabelScorer::choiceTransitionPreset( | ||
| "default", TransitionPresetType::DEFAULT, | ||
| "none", TransitionPresetType::NONE, | ||
| "ctc", TransitionPresetType::CTC, | ||
| "transducer", TransitionPresetType::TRANSDUCER, | ||
| "lm", TransitionPresetType::LM, | ||
| Core::Choice::endMark()); | ||
|
|
||
| const Core::ParameterChoice LabelScorer::paramTransitionPreset( | ||
| "transition-preset", | ||
| &LabelScorer::choiceTransitionPreset, | ||
| "Preset for which transition types should be enabled for the label scorer. Disabled transition types get assigned score 0 and do not affect the ScoringContext.", | ||
| TransitionPresetType::DEFAULT); | ||
|
|
||
| const Core::ParameterStringVector LabelScorer::paramExtraTransitionTypes( | ||
| "extra-transition-types", | ||
| "Transition types that should be enabled in addition to the ones given by the preset.", | ||
| ","); | ||
|
|
||
| LabelScorer::LabelScorer(const Core::Configuration& config) | ||
| : Core::Component(config), | ||
| ignoredTransitionTypes_() { | ||
| auto ignoredTransitionTypeStrings = paramIgnoredTransitionTypes(config); | ||
| for (auto const& transitionTypeString : ignoredTransitionTypeStrings) { | ||
| auto it = std::find_if(transitionTypeArray_.begin(), | ||
| transitionTypeArray_.end(), | ||
| [&](auto const& entry) { return entry.first == transitionTypeString; }); | ||
| if (it != transitionTypeArray_.end()) { | ||
| ignoredTransitionTypes_.insert(it->second); | ||
| } | ||
| else { | ||
| error() << "Ignored transition type name '" << transitionTypeString << "' is not a valid identifier"; | ||
| } | ||
| } | ||
| enabledTransitionTypes_() { | ||
| enableTransitionTypes(config); | ||
|
||
| } | ||
|
|
||
| void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { | ||
|
|
@@ -53,17 +56,17 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) { | |
| } | ||
|
|
||
| ScoringContextRef LabelScorer::extendedScoringContext(Request const& request) { | ||
| if (ignoredTransitionTypes_.contains(request.transitionType)) { | ||
| return request.context; | ||
| if (enabledTransitionTypes_.contains(request.transitionType)) { | ||
| return extendedScoringContextInternal(request); | ||
| } | ||
| return extendedScoringContextInternal(request); | ||
| return request.context; | ||
| } | ||
|
|
||
| std::optional<LabelScorer::ScoreWithTime> LabelScorer::computeScoreWithTime(Request const& request) { | ||
| if (ignoredTransitionTypes_.contains(request.transitionType)) { | ||
| return ScoreWithTime{0.0, 0}; | ||
| if (enabledTransitionTypes_.contains(request.transitionType)) { | ||
| return computeScoreWithTimeInternal(request); | ||
| } | ||
| return computeScoreWithTimeInternal(request); | ||
| return ScoreWithTime{0.0, 0}; | ||
| } | ||
|
|
||
| std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) { | ||
|
|
@@ -76,7 +79,7 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes( | |
|
|
||
| for (size_t requestIndex = 0ul; requestIndex < requests.size(); ++requestIndex) { | ||
| auto const& request = requests[requestIndex]; | ||
| if (not ignoredTransitionTypes_.contains(request.transitionType)) { | ||
| if (enabledTransitionTypes_.contains(request.transitionType)) { | ||
| nonIgnoredRequests.push_back(request); | ||
| nonIgnoredRequestIndices.push_back(requestIndex); | ||
| } | ||
|
|
@@ -89,12 +92,13 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes( | |
| } | ||
|
|
||
| // Interleave actual results with 0 scores for requests with ignored transition types | ||
| ScoresWithTimes result{{requests.size(), 0.0}, {requests.size(), 0}}; | ||
| ScoresWithTimes result{ | ||
| .scores = std::vector<Score>(requests.size(), 0.0), | ||
| .timeframes{requests.size(), 0}}; | ||
curufinwe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for (size_t i = 0ul; i < nonIgnoredRequestIndices.size(); ++i) { | ||
| auto requestResult = nonIgnoredResults[i]; | ||
| auto requestIndex = nonIgnoredRequestIndices[i]; | ||
| result.scores[requestIndex] = requestResult.score; | ||
| result.timeframes.set(requestIndex, requestResult.timeframe); | ||
| result.scores[requestIndex] = nonIgnoredResults->scores[i]; | ||
| result.timeframes.set(requestIndex, nonIgnoredResults->timeframes[i]); | ||
| } | ||
|
|
||
| return result; | ||
|
|
@@ -122,4 +126,66 @@ std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimesI | |
| return result; | ||
| } | ||
|
|
||
| void LabelScorer::enableTransitionTypes(Core::Configuration const& config) { | ||
| auto preset = paramTransitionPreset(config); | ||
| if (preset == TransitionPresetType::DEFAULT) { | ||
| preset = defaultPreset(); | ||
| } | ||
| verify(preset != TransitionPresetType::DEFAULT); | ||
|
|
||
| switch (preset) { | ||
| case TransitionPresetType::NONE: | ||
| break; | ||
| case TransitionPresetType::ALL: | ||
| for (auto const& [_, transitionType] : transitionTypeArray_) { | ||
| enabledTransitionTypes_.insert(transitionType); | ||
| } | ||
| break; | ||
| case TransitionPresetType::CTC: | ||
| enabledTransitionTypes_ = { | ||
| LABEL_TO_LABEL, | ||
| LABEL_LOOP, | ||
| LABEL_TO_BLANK, | ||
| BLANK_TO_LABEL, | ||
| BLANK_LOOP, | ||
| INITIAL_LABEL, | ||
| INITIAL_BLANK, | ||
| }; | ||
| break; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is related to PR #152, but the TRANSDUCER and LM presets additionally need the SENTENCE_END transition types. |
||
| case TransitionPresetType::TRANSDUCER: | ||
| enabledTransitionTypes_ = { | ||
| LABEL_TO_LABEL, | ||
| LABEL_TO_BLANK, | ||
| BLANK_TO_LABEL, | ||
| BLANK_LOOP, | ||
| INITIAL_LABEL, | ||
| INITIAL_BLANK, | ||
| }; | ||
| break; | ||
| case TransitionPresetType::LM: | ||
| enabledTransitionTypes_ = { | ||
| LABEL_TO_LABEL, | ||
| INITIAL_LABEL, | ||
|
Comment on lines +165 to +168
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The AedTreeBuilder and labelsync searches are still WIP and not merged yet, but at some point we will also need a preset for AED, and I guess it will be the same as this one. Will we then just add a preset that is exactly the same, only with a different name? If so, should LM or AED be the default preset of the StatefulOnnxLabelScorer? In the end they are the same, but the naming might become confusing. |
||
| }; | ||
| break; | ||
| } | ||
|
|
||
| auto extraTransitionTypeStrings = paramExtraTransitionTypes(config); | ||
| for (auto const& transitionTypeString : extraTransitionTypeStrings) { | ||
| auto it = std::find_if(transitionTypeArray_.begin(), | ||
| transitionTypeArray_.end(), | ||
| [&](auto const& entry) { return entry.first == transitionTypeString; }); | ||
| if (it != transitionTypeArray_.end()) { | ||
| enabledTransitionTypes_.insert(it->second); | ||
| } | ||
| else { | ||
| error() << "Extra transition type name '" << transitionTypeString << "' is not a valid identifier"; | ||
| } | ||
| } | ||
|
|
||
| if (enabledTransitionTypes_.empty()) { | ||
| error() << "Label scorer has no enabled transition types. Activate a preset and/or add extra transition types that should be considered."; | ||
| } | ||
| } | ||
|
|
||
| } // namespace Nn | ||
Uh oh!
There was an error while loading. Please reload this page.