diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index eccf9c61..2bbca11a 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -22,47 +22,31 @@ namespace Nn { TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config) : Core::Component(config), Precursor(config, TransitionPresetType::ALL), - transitionScores_(), - baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) { + transitionScores_() { for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) { auto paramName = std::string(stringIdentifier) + "-score"; transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config); } } -void TransitionLabelScorer::reset() { - baseLabelScorer_->reset(); -} +void TransitionLabelScorer::reset() {} -void TransitionLabelScorer::signalNoMoreFeatures() { - baseLabelScorer_->signalNoMoreFeatures(); -} +void TransitionLabelScorer::signalNoMoreFeatures() {} ScoringContextRef TransitionLabelScorer::getInitialScoringContext() { - return baseLabelScorer_->getInitialScoringContext(); -} - -void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { - baseLabelScorer_->cleanupCaches(activeContexts); + return Core::ref(new ScoringContext()); } -void TransitionLabelScorer::addInput(DataView const& input) { - baseLabelScorer_->addInput(input); -} - -void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) { - baseLabelScorer_->addInputs(input, nTimesteps); -} +void TransitionLabelScorer::addInput(DataView const& input) {} ScoringContextRef TransitionLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) { - return baseLabelScorer_->extendedScoringContext(request); + return Core::ref(new ScoringContext()); } std::optional TransitionLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) { - auto result = baseLabelScorer_->computeScoreWithTime(request); - if (result) { - result->score += transitionScores_[request.transitionType]; - } + LabelScorer::ScoreWithTime result; + result.score = transitionScores_[request.transitionType]; + result.timeframe = static_cast(0); return result; } @@ -71,11 +55,10 @@ std::optional TransitionLabelScorer::computeScores return ScoresWithTimes{}; } - auto results = baseLabelScorer_->computeScoresWithTimes(requests); - if (results) { - for (size_t i = 0ul; i < requests.size(); ++i) { - results->scores[i] += transitionScores_[requests[i].transitionType]; - } + LabelScorer::ScoresWithTimes results; + for (size_t i = 0ul; i < requests.size(); ++i) { + results.scores.push_back(transitionScores_[requests[i].transitionType]); + results.timeframes.push_back(static_cast(0)); } return results; } diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index b4a3d05e..cf222bc7 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -21,9 +21,9 @@ namespace Nn { /* - * This LabelScorer wraps a base LabelScorer and adds predefined transition scores - * to the base scores depending on the transition type of each request. + * This LabelScorer returns predefined transition scores depending on the transition type of each request. * The transition scores are all individually specified as config parameters. + * It should be used together with a main LabelScorer within the CombineLabelScorer */ class TransitionLabelScorer : public LabelScorer { public: @@ -32,38 +32,30 @@ public: TransitionLabelScorer(Core::Configuration const& config); virtual ~TransitionLabelScorer() = default; - // Reset base scorer + // No op void reset() override; - // Forward signal to base scorer + // No op void signalNoMoreFeatures() override; - // Initial context of base scorer + // Return dummy-context ScoringContextRef getInitialScoringContext() override; - // Clean up base scorer - void cleanupCaches(Core::CollapsedVector const& activeContexts) override; - - // Add input to base scorer + // No op void addInput(DataView const& input) override; - // Add inputs to sub-scorer - void addInputs(DataView const& input, size_t nTimesteps) override; - protected: - // Extend context via base scorer + // Return dummy-context ScoringContextRef extendedScoringContextInternal(Request const& request) override; - // Compute score of base scorer and add transition score based on transition type of the request + // Return transition score based on transition type of the request std::optional computeScoreWithTimeInternal(Request const& request) override; - // Compute scores of base scorer and add transition scores based on transition types of the requests + // Return transition scores based on transition types of the requests std::optional computeScoresWithTimesInternal(std::vector const& requests) override; private: std::unordered_map transitionScores_; - - Core::Ref baseLabelScorer_; }; } // namespace Nn