Skip to content

Commit f0f6a04

Browse files
authored
Add Nn::TransitionLabelScorer (#138)
1 parent d9145dd commit f0f6a04

File tree

5 files changed

+172
-1
lines changed

5 files changed

+172
-1
lines changed

src/Nn/LabelScorer/LabelScorer.hh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public:
8484
BLANK_LOOP,
8585
INITIAL_LABEL,
8686
INITIAL_BLANK,
87+
numTypes, // must remain at the end
8788
};
8889

8990
// Request for scoring or context extension
@@ -142,6 +143,18 @@ public:
142143
// Return two vectors: one vector with scores and one vector with times
143144
// By default loops over the single-request version
144145
virtual std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests);
146+
147+
protected:
148+
inline static constexpr auto transitionTypeArray_ = std::to_array<std::pair<std::string_view, TransitionType>>({
149+
{"label-to-label", LABEL_TO_LABEL},
150+
{"label-loop", LABEL_LOOP},
151+
{"label-to-blank", LABEL_TO_BLANK},
152+
{"blank-to-label", BLANK_TO_LABEL},
153+
{"blank-loop", BLANK_LOOP},
154+
{"initial-label", INITIAL_LABEL},
155+
{"initial-blank", INITIAL_BLANK},
156+
});
157+
static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values");
145158
};
146159

147160
} // namespace Nn

src/Nn/LabelScorer/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ LIBSPRINTLABELSCORER_O = \
2222
$(OBJDIR)/NoContextOnnxLabelScorer.o \
2323
$(OBJDIR)/NoOpLabelScorer.o \
2424
$(OBJDIR)/ScoringContext.o \
25-
$(OBJDIR)/StatefulOnnxLabelScorer.o
25+
$(OBJDIR)/StatefulOnnxLabelScorer.o \
26+
$(OBJDIR)/TransitionLabelScorer.o
2627

2728
# -----------------------------------------------------------------------------
2829

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#include "TransitionLabelScorer.hh"
17+
18+
#include <Nn/Module.hh>
19+
20+
namespace Nn {
21+
22+
TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config)
23+
: Core::Component(config),
24+
Precursor(config),
25+
transitionScores_(),
26+
baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) {
27+
for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) {
28+
auto paramName = std::string(stringIdentifier) + "-score";
29+
transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config);
30+
}
31+
}
32+
33+
void TransitionLabelScorer::reset() {
34+
baseLabelScorer_->reset();
35+
}
36+
37+
void TransitionLabelScorer::signalNoMoreFeatures() {
38+
baseLabelScorer_->signalNoMoreFeatures();
39+
}
40+
41+
ScoringContextRef TransitionLabelScorer::getInitialScoringContext() {
42+
return baseLabelScorer_->getInitialScoringContext();
43+
}
44+
45+
ScoringContextRef TransitionLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
46+
return baseLabelScorer_->extendedScoringContext(request);
47+
}
48+
49+
void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
50+
baseLabelScorer_->cleanupCaches(activeContexts);
51+
}
52+
53+
void TransitionLabelScorer::addInput(DataView const& input) {
54+
baseLabelScorer_->addInput(input);
55+
}
56+
57+
void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
58+
baseLabelScorer_->addInputs(input, nTimesteps);
59+
}
60+
61+
std::optional<LabelScorer::ScoreWithTime> TransitionLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
62+
auto result = baseLabelScorer_->computeScoreWithTime(request);
63+
if (result) {
64+
result->score += transitionScores_[request.transitionType];
65+
}
66+
return result;
67+
}
68+
69+
std::optional<LabelScorer::ScoresWithTimes> TransitionLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
70+
auto results = baseLabelScorer_->computeScoresWithTimes(requests);
71+
if (results) {
72+
for (size_t i = 0ul; i < requests.size(); ++i) {
73+
results->scores[i] += transitionScores_[requests[i].transitionType];
74+
}
75+
}
76+
return results;
77+
}
78+
79+
} // namespace Nn
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#ifndef TRANSITION_LABEL_SCORER_HH
17+
#define TRANSITION_LABEL_SCORER_HH
18+
19+
#include "LabelScorer.hh"
20+
21+
namespace Nn {
22+
23+
/*
24+
* This LabelScorer wraps a base LabelScorer and adds predefined transition scores
25+
* to the base scores depending on the transition type of each request.
26+
* The transition scores are all individually specified as config parameters.
27+
*/
28+
class TransitionLabelScorer : public LabelScorer {
29+
public:
30+
using Precursor = LabelScorer;
31+
32+
TransitionLabelScorer(Core::Configuration const& config);
33+
virtual ~TransitionLabelScorer() = default;
34+
35+
// Reset base scorer
36+
void reset() override;
37+
38+
// Forward signal to base scorer
39+
void signalNoMoreFeatures() override;
40+
41+
// Initial context of base scorer
42+
ScoringContextRef getInitialScoringContext() override;
43+
44+
// Extend context via base scorer
45+
ScoringContextRef extendedScoringContext(Request const& request) override;
46+
47+
// Clean up base scorer
48+
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
49+
50+
// Add input to base scorer
51+
void addInput(DataView const& input) override;
52+
53+
// Add inputs to sub-scorer
54+
void addInputs(DataView const& input, size_t nTimesteps) override;
55+
56+
// Compute score of base scorer and add transition score based on transition type of the request
57+
std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) override;
58+
59+
// Compute scores of base scorer and add transition scores based on transition types of the requests
60+
std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests) override;
61+
62+
private:
63+
std::unordered_map<TransitionType, Score> transitionScores_;
64+
65+
Core::Ref<LabelScorer> baseLabelScorer_;
66+
};
67+
68+
} // namespace Nn
69+
70+
#endif // TRANSITION_LABEL_SCORER_HH

src/Nn/Module.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "LabelScorer/NoContextOnnxLabelScorer.hh"
2525
#include "LabelScorer/NoOpLabelScorer.hh"
2626
#include "LabelScorer/StatefulOnnxLabelScorer.hh"
27+
#include "LabelScorer/TransitionLabelScorer.hh"
2728
#include "Statistics.hh"
2829

2930
#ifdef MODULE_NN
@@ -128,6 +129,13 @@ Module_::Module_()
128129
[](Core::Configuration const& config) {
129130
return Core::ref(new StatefulOnnxLabelScorer(config));
130131
});
132+
133+
// Returns predefined scores based on the transition type of each score request
134+
labelScorerFactory_.registerLabelScorer(
135+
"transition",
136+
[](Core::Configuration const& config) {
137+
return Core::ref(new TransitionLabelScorer(config));
138+
});
131139
};
132140

133141
Module_::~Module_() {

0 commit comments

Comments
 (0)