Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions src/Nn/LabelScorer/TransitionLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,47 +22,31 @@ namespace Nn {
TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config)
: Core::Component(config),
Precursor(config),
transitionScores_(),
baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) {
transitionScores_() {
for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) {
auto paramName = std::string(stringIdentifier) + "-score";
transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config);
}
}

void TransitionLabelScorer::reset() {
baseLabelScorer_->reset();
}
void TransitionLabelScorer::reset() {}

void TransitionLabelScorer::signalNoMoreFeatures() {
baseLabelScorer_->signalNoMoreFeatures();
}
void TransitionLabelScorer::signalNoMoreFeatures() {}

ScoringContextRef TransitionLabelScorer::getInitialScoringContext() {
return baseLabelScorer_->getInitialScoringContext();
}

void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
baseLabelScorer_->cleanupCaches(activeContexts);
return Core::ref(new StepScoringContext());
}

void TransitionLabelScorer::addInput(DataView const& input) {
baseLabelScorer_->addInput(input);
}

void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
baseLabelScorer_->addInputs(input, nTimesteps);
}
void TransitionLabelScorer::addInput(DataView const& input) {}

ScoringContextRef TransitionLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) {
return baseLabelScorer_->extendedScoringContext(request);
return Core::ref(new StepScoringContext());
}

std::optional<LabelScorer::ScoreWithTime> TransitionLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
auto result = baseLabelScorer_->computeScoreWithTime(request);
if (result) {
result->score += transitionScores_[request.transitionType];
}
LabelScorer::ScoreWithTime result;
result.score = transitionScores_[request.transitionType];
result.timeframe = static_cast<Speech::TimeframeIndex>(0);
return result;
}

Expand All @@ -71,11 +55,10 @@ std::optional<LabelScorer::ScoresWithTimes> TransitionLabelScorer::computeScores
return ScoresWithTimes{};
}

auto results = baseLabelScorer_->computeScoresWithTimes(requests);
if (results) {
for (size_t i = 0ul; i < requests.size(); ++i) {
results->scores[i] += transitionScores_[requests[i].transitionType];
}
LabelScorer::ScoresWithTimes results;
for (size_t i = 0ul; i < requests.size(); ++i) {
results.scores.push_back(transitionScores_[requests[i].transitionType]);
results.timeframes.push_back(static_cast<Speech::TimeframeIndex>(0));
}
return results;
}
Expand Down
26 changes: 9 additions & 17 deletions src/Nn/LabelScorer/TransitionLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
namespace Nn {

/*
* This LabelScorer wraps a base LabelScorer and adds predefined transition scores
* to the base scores depending on the transition type of each request.
* This LabelScorer returns predefined transition scores depending on the transition type of each request.
* The transition scores are all individually specified as config parameters.
* It should be used together with a main LabelScorer within the CombineLabelScorer
*/
class TransitionLabelScorer : public LabelScorer {
public:
Expand All @@ -32,32 +32,26 @@ public:
TransitionLabelScorer(Core::Configuration const& config);
virtual ~TransitionLabelScorer() = default;

// Reset base scorer
// No op
void reset() override;

// Forward signal to base scorer
// No op
void signalNoMoreFeatures() override;

// Initial context of base scorer
// Return dummy-context
ScoringContextRef getInitialScoringContext() override;

// Clean up base scorer
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

// Add input to base scorer
// No op
void addInput(DataView const& input) override;

// Add inputs to sub-scorer
void addInputs(DataView const& input, size_t nTimesteps) override;

protected:
// Extend context via base scorer
// Return dummy-context
ScoringContextRef extendedScoringContextInternal(Request const& request) override;

// Compute score of base scorer and add transition score based on transition type of the request
// Return transition score based on transition type of the request
std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) override;

// Compute scores of base scorer and add transition scores based on transition types of the requests
// Return transition scores based on transition types of the requests
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;

virtual TransitionPresetType defaultPreset() const override {
Expand All @@ -66,8 +60,6 @@ protected:

private:
std::unordered_map<TransitionType, Score> transitionScores_;

Core::Ref<LabelScorer> baseLabelScorer_;
};

} // namespace Nn
Expand Down