Skip to content

Commit 5cfe10a

Browse files
larissaklSimBe195curufinwe
authored
Update Nn::TransitionLabelScorer: remove wrapped LabelScorer (#155)
Co-authored-by: Simon Berger <[email protected]> Co-authored-by: Eugen Beck <[email protected]>
1 parent b4f3bc9 commit 5cfe10a

File tree

2 files changed

+22
-47
lines changed

2 files changed

+22
-47
lines changed

src/Nn/LabelScorer/TransitionLabelScorer.cc

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,31 @@ namespace Nn {
2222
TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config)
2323
: Core::Component(config),
2424
Precursor(config, TransitionPresetType::ALL),
25-
transitionScores_(),
26-
baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) {
25+
transitionScores_() {
2726
for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) {
2827
auto paramName = std::string(stringIdentifier) + "-score";
2928
transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config);
3029
}
3130
}
3231

33-
void TransitionLabelScorer::reset() {
34-
baseLabelScorer_->reset();
35-
}
32+
void TransitionLabelScorer::reset() {}
3633

37-
void TransitionLabelScorer::signalNoMoreFeatures() {
38-
baseLabelScorer_->signalNoMoreFeatures();
39-
}
34+
void TransitionLabelScorer::signalNoMoreFeatures() {}
4035

4136
ScoringContextRef TransitionLabelScorer::getInitialScoringContext() {
42-
return baseLabelScorer_->getInitialScoringContext();
43-
}
44-
45-
void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
46-
baseLabelScorer_->cleanupCaches(activeContexts);
37+
return Core::ref(new ScoringContext());
4738
}
4839

49-
void TransitionLabelScorer::addInput(DataView const& input) {
50-
baseLabelScorer_->addInput(input);
51-
}
52-
53-
void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
54-
baseLabelScorer_->addInputs(input, nTimesteps);
55-
}
40+
void TransitionLabelScorer::addInput(DataView const& input) {}
5641

5742
ScoringContextRef TransitionLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) {
58-
return baseLabelScorer_->extendedScoringContext(request);
43+
return Core::ref(new ScoringContext());
5944
}
6045

6146
std::optional<LabelScorer::ScoreWithTime> TransitionLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
62-
auto result = baseLabelScorer_->computeScoreWithTime(request);
63-
if (result) {
64-
result->score += transitionScores_[request.transitionType];
65-
}
47+
LabelScorer::ScoreWithTime result;
48+
result.score = transitionScores_[request.transitionType];
49+
result.timeframe = static_cast<Speech::TimeframeIndex>(0);
6650
return result;
6751
}
6852

@@ -71,11 +55,10 @@ std::optional<LabelScorer::ScoresWithTimes> TransitionLabelScorer::computeScores
7155
return ScoresWithTimes{};
7256
}
7357

74-
auto results = baseLabelScorer_->computeScoresWithTimes(requests);
75-
if (results) {
76-
for (size_t i = 0ul; i < requests.size(); ++i) {
77-
results->scores[i] += transitionScores_[requests[i].transitionType];
78-
}
58+
LabelScorer::ScoresWithTimes results;
59+
for (size_t i = 0ul; i < requests.size(); ++i) {
60+
results.scores.push_back(transitionScores_[requests[i].transitionType]);
61+
results.timeframes.push_back(static_cast<Speech::TimeframeIndex>(0));
7962
}
8063
return results;
8164
}

src/Nn/LabelScorer/TransitionLabelScorer.hh

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
namespace Nn {
2222

2323
/*
24-
* This LabelScorer wraps a base LabelScorer and adds predefined transition scores
25-
* to the base scores depending on the transition type of each request.
24+
* This LabelScorer returns predefined transition scores depending on the transition type of each request.
2625
* The transition scores are all individually specified as config parameters.
26+
* It should be used together with a main LabelScorer within the CombineLabelScorer
2727
*/
2828
class TransitionLabelScorer : public LabelScorer {
2929
public:
@@ -32,38 +32,30 @@ public:
3232
TransitionLabelScorer(Core::Configuration const& config);
3333
virtual ~TransitionLabelScorer() = default;
3434

35-
// Reset base scorer
35+
// No op
3636
void reset() override;
3737

38-
// Forward signal to base scorer
38+
// No op
3939
void signalNoMoreFeatures() override;
4040

41-
// Initial context of base scorer
41+
// Return dummy-context
4242
ScoringContextRef getInitialScoringContext() override;
4343

44-
// Clean up base scorer
45-
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
46-
47-
// Add input to base scorer
44+
// No op
4845
void addInput(DataView const& input) override;
4946

50-
// Add inputs to sub-scorer
51-
void addInputs(DataView const& input, size_t nTimesteps) override;
52-
5347
protected:
54-
// Extend context via base scorer
48+
// Return dummy-context
5549
ScoringContextRef extendedScoringContextInternal(Request const& request) override;
5650

57-
// Compute score of base scorer and add transition score based on transition type of the request
51+
// Return transition score based on transition type of the request
5852
std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) override;
5953

60-
// Compute scores of base scorer and add transition scores based on transition types of the requests
54+
// Return transition scores based on transition types of the requests
6155
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;
6256

6357
private:
6458
std::unordered_map<TransitionType, Score> transitionScores_;
65-
66-
Core::Ref<LabelScorer> baseLabelScorer_;
6759
};
6860

6961
} // namespace Nn

0 commit comments

Comments
 (0)