Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7502b82
Add TransitionLabelScorer
SimBe195 Jul 24, 2025
7430001
Rewrite docstring
SimBe195 Jul 24, 2025
2a6272e
Clean up includes
SimBe195 Jul 24, 2025
7e325e1
Rewrite docstring again
SimBe195 Jul 24, 2025
a276136
Merge branch 'master' into tdp_label_scorer
SimBe195 Sep 24, 2025
d2d78fe
Refactor params to string list with compile time check
SimBe195 Sep 24, 2025
303fa46
Remove transitionTypeToIndex function and revert associated changes
SimBe195 Sep 24, 2025
ddd75c7
Revert unnecessary static_cast
SimBe195 Sep 24, 2025
b856c1e
Change std=c++17 to c++20
SimBe195 Sep 30, 2025
70699c0
Merge remote-tracking branch 'origin/version-bump' into tdp_label_scorer
SimBe195 Sep 30, 2025
5b89d0f
Move transition type string array to LabelScorer.hh
SimBe195 Sep 30, 2025
8b27b19
Add parameter for ignoring transition types in LabelScorer
SimBe195 Oct 1, 2025
1e877c7
Add missing parenthesis in description
SimBe195 Oct 1, 2025
3dcadee
Add some docstrings for the `Internal` functions
SimBe195 Oct 1, 2025
b9d919b
Move transitionTypeArray to protected space
SimBe195 Oct 1, 2025
23df463
Merge branch 'tdp_label_scorer' into disabled-transition-types
SimBe195 Oct 1, 2025
a437c04
Merge branch 'master' into disabled-transition-types
curufinwe Oct 8, 2025
8337fea
Fix order in .cc files
curufinwe Oct 9, 2025
0cfdf3d
Add `set` function to Core::CollapsedVector
SimBe195 Oct 10, 2025
85522b5
Apply suggestions from code review
SimBe195 Oct 10, 2025
4098ad4
Fix compilation
SimBe195 Oct 10, 2025
0d34085
Introduce configurable presets of enabled transition types + extras
SimBe195 Oct 10, 2025
f365db2
Handle default transition type via constructor parameter instead of f…
SimBe195 Nov 5, 2025
6e755d8
Merge branch 'master' into disabled-transition-types
SimBe195 Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/Nn/LabelScorer/CombineLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ScoringContextRef CombineLabelScorer::getInitialScoringContext() {
return Core::ref(new CombineScoringContext(std::move(scoringContexts)));
}

ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& request) {
ScoringContextRef CombineLabelScorer::extendedScoringContextInternal(Request const& request) {
auto combineContext = dynamic_cast<const CombineScoringContext*>(request.context.get());

std::vector<ScoringContextRef> extScoringContexts;
Expand Down Expand Up @@ -101,7 +101,7 @@ void CombineLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
}
}

std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTime(Request const& request) {
std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTimeInternal(Request const& request) {
// Initialize accumulated result with zero-valued score and timestep
ScoreWithTime accumResult{0.0, 0};

Expand Down Expand Up @@ -130,7 +130,11 @@ std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTi
return accumResult;
}

std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWithTimes(std::vector<Request> const& requests) {
std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWithTimesInternal(std::vector<Request> const& requests) {
if (requests.empty()) {
return ScoresWithTimes{};
}

// Initialize accumulated results with zero-valued scores and timesteps
ScoresWithTimes accumResult{std::vector<Score>(requests.size(), 0.0), {requests.size(), 0}};

Expand Down
18 changes: 9 additions & 9 deletions src/Nn/LabelScorer/CombineLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ public:
// Combine initial ScoringContexts from all sub-scorers
ScoringContextRef getInitialScoringContext() override;

// Combine extended ScoringContexts from all sub-scorers
ScoringContextRef extendedScoringContext(Request const& request) override;

// Cleanup all sub-scorers
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

Expand All @@ -58,19 +55,22 @@ public:
// Add inputs to all sub-scorers
virtual void addInputs(DataView const& input, size_t nTimesteps) override;

// Compute weighted score of request with all sub-scorers
std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) override;

// Compute weighted scores of requests with all sub-scorers
std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests) override;

protected:
struct ScaledLabelScorer {
Core::Ref<LabelScorer> scorer;
Score scale;
};

std::vector<ScaledLabelScorer> scaledScorers_;

// Combine extended ScoringContexts from all sub-scorers
ScoringContextRef extendedScoringContextInternal(Request const& request) override;

// Compute weighted score of request with all sub-scorers
std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) override;

// Compute weighted scores of requests with all sub-scorers
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;
};

} // namespace Nn
Expand Down
10 changes: 7 additions & 3 deletions src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ScoringContextRef EncoderDecoderLabelScorer::getInitialScoringContext() {
return decoder_->getInitialScoringContext();
}

ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request const& request) {
// Context extension is handled entirely by the decoder component: the request is forwarded to the
// decoder's public `extendedScoringContext` and its result is returned unchanged.
ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContextInternal(Request const& request) {
    auto extendedContext = decoder_->extendedScoringContext(request);
    return extendedContext;
}

Expand All @@ -59,11 +59,15 @@ void EncoderDecoderLabelScorer::signalNoMoreFeatures() {
decoder_->signalNoMoreFeatures();
}

std::optional<LabelScorer::ScoreWithTime> EncoderDecoderLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
// Scores a single request by forwarding it to the decoder's public `computeScoreWithTime`.
// The decoder's (possibly empty) optional is passed through as-is.
std::optional<LabelScorer::ScoreWithTime> EncoderDecoderLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
    auto decoderResult = decoder_->computeScoreWithTime(request);
    return decoderResult;
}

std::optional<LabelScorer::ScoresWithTimes> EncoderDecoderLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
// Scores a batch of requests with the decoder component.
// An empty batch yields an empty (but present) result without involving the decoder at all.
std::optional<LabelScorer::ScoresWithTimes> EncoderDecoderLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
    if (not requests.empty()) {
        // Forward the whole batch to the decoder's public batched scoring function
        return decoder_->computeScoresWithTimes(requests);
    }

    return ScoresWithTimes{};
}

Expand Down
11 changes: 6 additions & 5 deletions src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ public:
// Get start context from decoder component
ScoringContextRef getInitialScoringContext() override;

// Get extended context from decoder component
ScoringContextRef extendedScoringContext(Request const& request) override;

// Cleanup decoder component. Encoder is "self-cleaning" already in that it only stores outputs until they are
// retrieved.
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
Expand All @@ -60,11 +57,15 @@ public:
// Same as `addInput` but adds features for multiple timesteps at once
void addInputs(DataView const& input, size_t nTimesteps) override;

protected:
// Get extended context from decoder component
ScoringContextRef extendedScoringContextInternal(Request const& request) override;

// Run request through decoder component
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

// Run requests through decoder component
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

private:
Core::Ref<Encoder> encoder_;
Expand Down
10 changes: 7 additions & 3 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::getInitialScoringContext() {
return hist;
}

ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) {
SeqStepScoringContextRef context(dynamic_cast<const SeqStepScoringContext*>(request.context.get()));

bool pushToken = false;
Expand Down Expand Up @@ -159,7 +159,11 @@ void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringCon
}
}

std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
if (requests.empty()) {
return ScoresWithTimes{};
}

ScoresWithTimes result;
result.scores.reserve(requests.size());

Expand Down Expand Up @@ -232,7 +236,7 @@ std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::compute
return result;
}

std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
auto result = computeScoresWithTimes({request});
if (not result.has_value()) {
return {};
Expand Down
18 changes: 9 additions & 9 deletions src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,24 @@ public:
// Initial scoring context contains step 0 and a history vector filled with the start label index
ScoringContextRef getInitialScoringContext() override;

// Clean up input buffer as well as cached score vectors that are no longer needed
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;

// May increment the step by 1 (except for vertical transitions) and may append the next token to the
// history label sequence depending on the transition type and whether loops/blanks update the history
// or not
ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override;

// Clean up input buffer as well as cached score vectors that are no longer needed
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override;

// If scores for the given scoring contexts are not yet cached, prepare and run an ONNX session to
// compute the scores and cache them
// Then, retrieve scores from cache
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

// Uses `computeScoresWithTimes` internally with some wrapping for vector packing/expansion
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;

protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

private:
// Forward a batch of histories through the ONNX model and put the resulting scores into the score cache
Expand Down
81 changes: 80 additions & 1 deletion src/Nn/LabelScorer/LabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,30 @@ namespace Nn {
* =============================
*/

// Configuration parameter "ignored-transition-types": a list of transition type names.
// The LabelScorer constructor compares each entry against the string identifiers in
// `transitionTypeArray_` and reports an error for names that do not match any of them.
// NOTE(review): the trailing "," argument is presumably the list delimiter - confirm against
// Core::ParameterStringVector.
const Core::ParameterStringVector LabelScorer::paramIgnoredTransitionTypes(
"ignored-transition-types",
"Transition types that should be ignored by the label scorer (i.e. get assigned score 0 and do not affect the ScoringContext)",
",");

LabelScorer::LabelScorer(const Core::Configuration& config)
: Core::Component(config) {}
: Core::Component(config),
ignoredTransitionTypes_() {
auto ignoredTransitionTypeStrings = paramIgnoredTransitionTypes(config);
for (auto const& transitionTypeString : ignoredTransitionTypeStrings) {
bool identifierFound = false;
for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) {
if (stringIdentifier == transitionTypeString) {
ignoredTransitionTypes_.insert(enumValue);
identifierFound = true;
break;
}
}

if (not identifierFound) {
error() << "Ignored transition type name '" << transitionTypeString << "' is not a valid identifier";
}
}
}

void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
auto featureSize = input.size() / nTimesteps;
Expand All @@ -33,7 +55,64 @@ void LabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
}
}

// Public entry point for extending a ScoringContext.
// Requests whose transition type is configured as ignored get their incoming context back unchanged;
// every other request is delegated to the scorer-specific `extendedScoringContextInternal`.
ScoringContextRef LabelScorer::extendedScoringContext(Request const& request) {
    bool const transitionIsIgnored = ignoredTransitionTypes_.contains(request.transitionType);
    if (not transitionIsIgnored) {
        return extendedScoringContextInternal(request);
    }

    // Ignored transitions must not affect the context
    return request.context;
}

// Public entry point for scoring a single request.
// Requests whose transition type is configured as ignored receive a zero score with timeframe 0;
// every other request is delegated to the scorer-specific `computeScoreWithTimeInternal`, whose
// (possibly empty) optional is returned as-is.
std::optional<LabelScorer::ScoreWithTime> LabelScorer::computeScoreWithTime(Request const& request) {
    bool const transitionIsIgnored = ignoredTransitionTypes_.contains(request.transitionType);
    if (not transitionIsIgnored) {
        return computeScoreWithTimeInternal(request);
    }

    return ScoreWithTime{0.0, 0};
}

// Public entry point for scoring a batch of requests.
//
// Requests whose transition type is configured as ignored are not passed to the scorer-specific
// `computeScoresWithTimesInternal`; they receive score 0.0 and timeframe 0 in the output instead.
// The returned scores/timeframes are in the same order as `requests`.
// Returns an empty optional if the internal scoring function does not return a value.
std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
    // Fast path: without any ignored transition types there is nothing to filter or interleave,
    // so avoid copying the requests and rebuilding the result element by element.
    if (ignoredTransitionTypes_.empty()) {
        return computeScoresWithTimesInternal(requests);
    }

    // First, collect all requests for which the transition type is not ignored
    std::vector<Request> nonIgnoredRequests;
    nonIgnoredRequests.reserve(requests.size());
    for (auto const& request : requests) {
        if (not ignoredTransitionTypes_.contains(request.transitionType)) {
            nonIgnoredRequests.push_back(request);
        }
    }

    // Compute scores for non-ignored requests
    auto nonIgnoredResults = computeScoresWithTimesInternal(nonIgnoredRequests);
    if (not nonIgnoredResults) {
        return {};
    }

    // Interleave actual results with 0 scores for requests with ignored transition types.
    // The ignore decision is re-derived from the transition type (a lookup in a small set) rather
    // than tracked in a separate index set; both passes visit the requests in the same order, so
    // `nonIgnoredResultsIdx` always points at the result belonging to the current request.
    ScoresWithTimes result;
    result.scores.reserve(requests.size());
    size_t nonIgnoredResultsIdx = 0ul;
    for (auto const& request : requests) {
        if (ignoredTransitionTypes_.contains(request.transitionType)) {
            result.scores.push_back(0.0);
            result.timeframes.push_back(0);
        }
        else {
            result.scores.push_back(nonIgnoredResults->scores[nonIgnoredResultsIdx]);
            result.timeframes.push_back(nonIgnoredResults->timeframes[nonIgnoredResultsIdx]);
            ++nonIgnoredResultsIdx;
        }
    }

    return result;
}

std::optional<LabelScorer::ScoresWithTimes> LabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
if (requests.empty()) {
return ScoresWithTimes{};
}

// By default, just loop over the non-batched `computeScoreWithTime` and collect the results
ScoresWithTimes result;

Expand Down
19 changes: 16 additions & 3 deletions src/Nn/LabelScorer/LabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ namespace Nn {
*/
class LabelScorer : public virtual Core::Component,
public Core::ReferenceCounted {
static const Core::ParameterStringVector paramIgnoredTransitionTypes;

public:
typedef Search::Score Score;

Expand Down Expand Up @@ -120,7 +122,7 @@ public:
virtual ScoringContextRef getInitialScoringContext() = 0;

// Creates a copy of the context in the request that is extended using the given token and transition type
virtual ScoringContextRef extendedScoringContext(Request const& request) = 0;
virtual ScoringContextRef extendedScoringContext(Request const& request);

// Given a collection of currently active contexts, this function can clean up values in any internal caches
// or buffers that are saved for scoring contexts which no longer are active.
Expand All @@ -136,12 +138,11 @@ public:
// Return score and timeframe index of the corresponding output
// May not return a value if the LabelScorer is not ready to score the request yet
// (e.g. not enough features received)
virtual std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) = 0;
virtual std::optional<ScoreWithTime> computeScoreWithTime(Request const& request);

// Perform scoring computation for a batch of requests
// May be implemented more efficiently than iterated calls of `computeScoreWithTime`
// Return two vectors: one vector with scores and one vector with times
// By default loops over the single-request version
virtual std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests);

protected:
Expand All @@ -155,6 +156,18 @@ protected:
{"initial-blank", INITIAL_BLANK},
});
static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values");

// The public versions of these functions are implemented in this base class and handle the ignoring of transition types.
// These `Internal` versions contain the actual logic and should be overridden in child classes.

virtual ScoringContextRef extendedScoringContextInternal(Request const& request) = 0;
virtual std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) = 0;

// By default loops over the single-request version
virtual std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests);

private:
std::unordered_set<TransitionType> ignoredTransitionTypes_;
};

} // namespace Nn
Expand Down
10 changes: 7 additions & 3 deletions src/Nn/LabelScorer/NoContextOnnxLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ ScoringContextRef NoContextOnnxLabelScorer::getInitialScoringContext() {
return Core::ref(new StepScoringContext());
}

ScoringContextRef NoContextOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
// Builds the successor context of the request's context: a fresh StepScoringContext whose step
// is the current step plus one.
// NOTE(review): the result of the dynamic_cast is dereferenced without a null check, so the
// request's context is expected to be a StepScoringContext - confirm callers guarantee this.
ScoringContextRef NoContextOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) {
    auto const* rawContext = dynamic_cast<const StepScoringContext*>(request.context.get());
    StepScoringContextRef stepContext(rawContext);

    return Core::ref(new StepScoringContext(stepContext->currentStep + 1));
}
Expand All @@ -73,7 +73,11 @@ void NoContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContex
}
}

std::optional<LabelScorer::ScoresWithTimes> NoContextOnnxLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
std::optional<LabelScorer::ScoresWithTimes> NoContextOnnxLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
if (requests.empty()) {
return ScoresWithTimes{};
}

ScoresWithTimes result;
result.scores.reserve(requests.size());

Expand Down Expand Up @@ -115,7 +119,7 @@ std::optional<LabelScorer::ScoresWithTimes> NoContextOnnxLabelScorer::computeSco
return result;
}

std::optional<LabelScorer::ScoreWithTime> NoContextOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
std::optional<LabelScorer::ScoreWithTime> NoContextOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
auto result = computeScoresWithTimes({request});
if (not result.has_value()) {
return {};
Expand Down
16 changes: 8 additions & 8 deletions src/Nn/LabelScorer/NoContextOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ public:
// Initial scoring context contains step 0
ScoringContextRef getInitialScoringContext() override;

// Increment the step by 1
ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override;

// Clean up input buffer as well as cached score vectors that are no longer needed
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;

// Increment the step by 1
ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override;

// If scores for the given scoring contexts are not yet cached, prepare and run an ONNX session to
// compute the scores and cache them
// Then, retrieve scores from cache
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;

// Uses `computeScoresWithTimes` internally with some wrapping for vector packing/expansion
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;

protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;

private:
void forwardContext(StepScoringContextRef const& context);
Expand Down
Loading