Skip to content

Commit 54dabf1

Browse files
SimBe195curufinwe
andauthored
Update Nn::LabelScorer: allow enabling of specific transition types (#148)
Co-authored-by: Eugen Beck <[email protected]>
1 parent c612950 commit 54dabf1

19 files changed

+375
-167
lines changed

src/Core/CollapsedVector.hh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public:
4545
inline void push_back(const T& value);
4646
inline const T& operator[](size_t idx) const;
4747
inline const T& at(size_t idx) const;
48+
inline void set(size_t idx, const T& value);
4849
inline size_t size() const noexcept;
4950
inline void clear() noexcept;
5051
inline void reserve(size_t size);
@@ -105,6 +106,21 @@ inline const T& CollapsedVector<T>::at(size_t idx) const {
105106
return (*this)[idx];
106107
}
107108

109+
template<typename T>
110+
inline void CollapsedVector<T>::set(size_t idx, const T& value) {
111+
if (idx >= logicalSize_) {
112+
throw std::out_of_range("Trying to access illegal index of CollapsedVector");
113+
}
114+
if (data_.size() != 1ul) {
115+
data_[idx] = value;
116+
data_.push_back(value);
117+
}
118+
else if (value != data_.front()) {
119+
data_.resize(logicalSize_, data_.front());
120+
data_[idx] = value;
121+
}
122+
}
123+
108124
template<typename T>
109125
inline size_t CollapsedVector<T>::size() const noexcept {
110126
return logicalSize_;

src/Nn/LabelScorer/BufferedLabelScorer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
namespace Nn {
1919

20-
BufferedLabelScorer::BufferedLabelScorer(Core::Configuration const& config)
20+
BufferedLabelScorer::BufferedLabelScorer(Core::Configuration const& config, TransitionPresetType defaultPreset)
2121
: Core::Component(config),
22-
Precursor(config),
22+
Precursor(config, defaultPreset),
2323
expectMoreFeatures_(true),
2424
inputBuffer_(),
2525
numDeletedInputs_(0ul) {

src/Nn/LabelScorer/BufferedLabelScorer.hh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class BufferedLabelScorer : public LabelScorer {
3131
public:
3232
using Precursor = LabelScorer;
3333

34-
BufferedLabelScorer(Core::Configuration const& config);
34+
BufferedLabelScorer(Core::Configuration const& config, TransitionPresetType defaultPreset);
3535

3636
// Prepares the LabelScorer to receive new inputs by resetting input buffer, timeframe buffer
3737
// and segment end flag

src/Nn/LabelScorer/CombineLabelScorer.cc

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ Core::ParameterFloat CombineLabelScorer::paramScale(
2525
"scale", "Scores of a sub-label-scorer are scaled by this factor", 1.0);
2626

2727
CombineLabelScorer::CombineLabelScorer(Core::Configuration const& config)
28-
: Core::Component(config), Precursor(config), scaledScorers_() {
28+
: Core::Component(config),
29+
Precursor(config, TransitionPresetType::ALL) {
2930
size_t numLabelScorers = paramNumLabelScorers(config);
3031
for (size_t i = 0ul; i < numLabelScorers; ++i) {
3132
Core::Configuration subConfig = select(std::string("scorer-") + std::to_string(i + 1));
@@ -55,22 +56,6 @@ ScoringContextRef CombineLabelScorer::getInitialScoringContext() {
5556
return Core::ref(new CombineScoringContext(std::move(scoringContexts)));
5657
}
5758

58-
ScoringContextRef CombineLabelScorer::extendedScoringContext(Request const& request) {
59-
auto combineContext = dynamic_cast<const CombineScoringContext*>(request.context.get());
60-
61-
std::vector<ScoringContextRef> extScoringContexts;
62-
extScoringContexts.reserve(scaledScorers_.size());
63-
64-
auto scorerIt = scaledScorers_.begin();
65-
auto contextIt = combineContext->scoringContexts.begin();
66-
67-
for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) {
68-
Request subRequest{*contextIt, request.nextToken, request.transitionType};
69-
extScoringContexts.push_back(scorerIt->scorer->extendedScoringContext(subRequest));
70-
}
71-
return Core::ref(new CombineScoringContext(std::move(extScoringContexts)));
72-
}
73-
7459
void CombineLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
7560
std::vector<const CombineScoringContext*> combineContexts;
7661
combineContexts.reserve(activeContexts.internalSize());
@@ -101,7 +86,23 @@ void CombineLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
10186
}
10287
}
10388

104-
std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTime(Request const& request) {
89+
ScoringContextRef CombineLabelScorer::extendedScoringContextInternal(Request const& request) {
90+
auto combineContext = dynamic_cast<const CombineScoringContext*>(request.context.get());
91+
92+
std::vector<ScoringContextRef> extScoringContexts;
93+
extScoringContexts.reserve(scaledScorers_.size());
94+
95+
auto scorerIt = scaledScorers_.begin();
96+
auto contextIt = combineContext->scoringContexts.begin();
97+
98+
for (; scorerIt != scaledScorers_.end(); ++scorerIt, ++contextIt) {
99+
Request subRequest{*contextIt, request.nextToken, request.transitionType};
100+
extScoringContexts.push_back(scorerIt->scorer->extendedScoringContext(subRequest));
101+
}
102+
return Core::ref(new CombineScoringContext(std::move(extScoringContexts)));
103+
}
104+
105+
std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTimeInternal(Request const& request) {
105106
// Initialize accumulated result with zero-valued score and timestep
106107
ScoreWithTime accumResult{0.0, 0};
107108

@@ -130,7 +131,11 @@ std::optional<LabelScorer::ScoreWithTime> CombineLabelScorer::computeScoreWithTi
130131
return accumResult;
131132
}
132133

133-
std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWithTimes(std::vector<Request> const& requests) {
134+
std::optional<LabelScorer::ScoresWithTimes> CombineLabelScorer::computeScoresWithTimesInternal(std::vector<Request> const& requests) {
135+
if (requests.empty()) {
136+
return ScoresWithTimes{};
137+
}
138+
134139
// Initialize accumulated results with zero-valued scores and timesteps
135140
ScoresWithTimes accumResult{std::vector<Score>(requests.size(), 0.0), {requests.size(), 0}};
136141

src/Nn/LabelScorer/CombineLabelScorer.hh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ public:
4646
// Combine initial ScoringContexts from all sub-scorers
4747
ScoringContextRef getInitialScoringContext() override;
4848

49-
// Combine extended ScoringContexts from all sub-scorers
50-
ScoringContextRef extendedScoringContext(Request const& request) override;
51-
5249
// Cleanup all sub-scorers
5350
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
5451

@@ -58,19 +55,22 @@ public:
5855
// Add inputs to all sub-scorers
5956
virtual void addInputs(DataView const& input, size_t nTimesteps) override;
6057

61-
// Compute weighted score of request with all sub-scorers
62-
std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) override;
63-
64-
// Compute weighted scores of requests with all sub-scorers
65-
std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests) override;
66-
6758
protected:
6859
struct ScaledLabelScorer {
6960
Core::Ref<LabelScorer> scorer;
7061
Score scale;
7162
};
7263

7364
std::vector<ScaledLabelScorer> scaledScorers_;
65+
66+
// Combine extended ScoringContexts from all sub-scorers
67+
ScoringContextRef extendedScoringContextInternal(Request const& request) override;
68+
69+
// Compute weighted score of request with all sub-scorers
70+
std::optional<ScoreWithTime> computeScoreWithTimeInternal(Request const& request) override;
71+
72+
// Compute weighted scores of requests with all sub-scorers
73+
std::optional<ScoresWithTimes> computeScoresWithTimesInternal(std::vector<Request> const& requests) override;
7474
};
7575

7676
} // namespace Nn

src/Nn/LabelScorer/EncoderDecoderLabelScorer.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace Nn {
1919

2020
EncoderDecoderLabelScorer::EncoderDecoderLabelScorer(Core::Configuration const& config, Core::Ref<Encoder> const& encoder, Core::Ref<LabelScorer> const& decoder)
2121
: Core::Component(config),
22-
LabelScorer(config),
22+
LabelScorer(config, TransitionPresetType::ALL),
2323
encoder_(encoder),
2424
decoder_(decoder) {
2525
}
@@ -33,10 +33,6 @@ ScoringContextRef EncoderDecoderLabelScorer::getInitialScoringContext() {
3333
return decoder_->getInitialScoringContext();
3434
}
3535

36-
ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContext(Request const& request) {
37-
return decoder_->extendedScoringContext(request);
38-
}
39-
4036
void EncoderDecoderLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
4137
decoder_->cleanupCaches(activeContexts);
4238
}
@@ -59,11 +55,19 @@ void EncoderDecoderLabelScorer::signalNoMoreFeatures() {
5955
decoder_->signalNoMoreFeatures();
6056
}
6157

62-
std::optional<LabelScorer::ScoreWithTime> EncoderDecoderLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
58+
ScoringContextRef EncoderDecoderLabelScorer::extendedScoringContextInternal(Request const& request) {
59+
return decoder_->extendedScoringContext(request);
60+
}
61+
62+
std::optional<LabelScorer::ScoreWithTime> EncoderDecoderLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
6363
return decoder_->computeScoreWithTime(request);
6464
}
6565

66-
std::optional<LabelScorer::ScoresWithTimes> EncoderDecoderLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
66+
std::optional<LabelScorer::ScoresWithTimes> EncoderDecoderLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
67+
if (requests.empty()) {
68+
return ScoresWithTimes{};
69+
}
70+
6771
return decoder_->computeScoresWithTimes(requests);
6872
}
6973

src/Nn/LabelScorer/EncoderDecoderLabelScorer.hh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ public:
4646
// Get start context from decoder component
4747
ScoringContextRef getInitialScoringContext() override;
4848

49-
// Get extended context from decoder component
50-
ScoringContextRef extendedScoringContext(Request const& request) override;
51-
5249
// Cleanup decoder component. Encoder is "self-cleaning" already in that it only stores outputs until they are
5350
// retrieved.
5451
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
@@ -60,11 +57,15 @@ public:
6057
// Same as `addInput` but adds features for multiple timesteps at once
6158
void addInputs(DataView const& input, size_t nTimesteps) override;
6259

60+
protected:
61+
// Get extended context from decoder component
62+
ScoringContextRef extendedScoringContextInternal(Request const& request) override;
63+
6364
// Run request through decoder component
64-
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;
65+
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;
6566

6667
// Run requests through decoder component
67-
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
68+
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;
6869

6970
private:
7071
Core::Ref<Encoder> encoder_;

src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ static const std::vector<Onnx::IOSpecification> ioSpec = {
7272

7373
FixedContextOnnxLabelScorer::FixedContextOnnxLabelScorer(Core::Configuration const& config)
7474
: Core::Component(config),
75-
Precursor(config),
75+
Precursor(config, TransitionPresetType::TRANSDUCER),
7676
startLabelIndex_(paramStartLabelIndex(config)),
7777
historyLength_(paramHistoryLength(config)),
7878
blankUpdatesHistory_(paramBlankUpdatesHistory(config)),
@@ -97,7 +97,32 @@ ScoringContextRef FixedContextOnnxLabelScorer::getInitialScoringContext() {
9797
return hist;
9898
}
9999

100-
ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
100+
size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const {
101+
auto minTimeIndex = Core::Type<Speech::TimeframeIndex>::max;
102+
for (auto const& context : activeContexts.internalData()) {
103+
SeqStepScoringContextRef stepHistory(dynamic_cast<const SeqStepScoringContext*>(context.get()));
104+
minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep);
105+
}
106+
107+
return minTimeIndex;
108+
}
109+
110+
void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
111+
Precursor::cleanupCaches(activeContexts);
112+
113+
std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> activeContextSet(activeContexts.internalData().begin(), activeContexts.internalData().end());
114+
115+
for (auto it = scoreCache_.begin(); it != scoreCache_.end();) {
116+
if (activeContextSet.find(it->first) == activeContextSet.end()) {
117+
it = scoreCache_.erase(it);
118+
}
119+
else {
120+
++it;
121+
}
122+
}
123+
}
124+
125+
ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContextInternal(LabelScorer::Request const& request) {
101126
SeqStepScoringContextRef context(dynamic_cast<const SeqStepScoringContext*>(request.context.get()));
102127

103128
bool pushToken = false;
@@ -144,22 +169,11 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore
144169
return Core::ref(new SeqStepScoringContext(std::move(newLabelSeq), context->currentStep + timeIncrement));
145170
}
146171

147-
void FixedContextOnnxLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
148-
Precursor::cleanupCaches(activeContexts);
149-
150-
std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> activeContextSet(activeContexts.internalData().begin(), activeContexts.internalData().end());
151-
152-
for (auto it = scoreCache_.begin(); it != scoreCache_.end();) {
153-
if (activeContextSet.find(it->first) == activeContextSet.end()) {
154-
it = scoreCache_.erase(it);
155-
}
156-
else {
157-
++it;
158-
}
172+
std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) {
173+
if (requests.empty()) {
174+
return ScoresWithTimes{};
159175
}
160-
}
161176

162-
std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
163177
ScoresWithTimes result;
164178
result.scores.reserve(requests.size());
165179

@@ -232,24 +246,14 @@ std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::compute
232246
return result;
233247
}
234248

235-
std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
249+
std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTimeInternal(LabelScorer::Request const& request) {
236250
auto result = computeScoresWithTimes({request});
237251
if (not result.has_value()) {
238252
return {};
239253
}
240254
return ScoreWithTime{result->scores.front(), result->timeframes.front()};
241255
}
242256

243-
size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const {
244-
auto minTimeIndex = Core::Type<Speech::TimeframeIndex>::max;
245-
for (auto const& context : activeContexts.internalData()) {
246-
SeqStepScoringContextRef stepHistory(dynamic_cast<const SeqStepScoringContext*>(context.get()));
247-
minTimeIndex = std::min(minTimeIndex, stepHistory->currentStep);
248-
}
249-
250-
return minTimeIndex;
251-
}
252-
253257
void FixedContextOnnxLabelScorer::forwardBatch(std::vector<SeqStepScoringContextRef> const& contextBatch) {
254258
if (contextBatch.empty()) {
255259
return;

src/Nn/LabelScorer/FixedContextOnnxLabelScorer.hh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,24 +47,24 @@ public:
4747
// Initial scoring context contains step 0 and a history vector filled with the start label index
4848
ScoringContextRef getInitialScoringContext() override;
4949

50+
// Clean up input buffer as well as cached score vectors that are no longer needed
51+
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
52+
53+
protected:
54+
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;
55+
5056
// May increment the step by 1 (except for vertical transitions) and may append the next token to the
5157
// history label sequence depending on the transition type and whether loops/blanks update the history
5258
// or not
53-
ScoringContextRef extendedScoringContext(LabelScorer::Request const& request) override;
54-
55-
// Clean up input buffer as well as cached score vectors that are no longer needed
56-
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;
59+
ScoringContextRef extendedScoringContextInternal(LabelScorer::Request const& request) override;
5760

5861
// If scores for the given scoring contexts are not yet cached, prepare and run an ONNX session to
5962
// compute the scores and cache them
6063
// Then, retreive scores from cache
61-
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) override;
64+
std::optional<LabelScorer::ScoresWithTimes> computeScoresWithTimesInternal(std::vector<LabelScorer::Request> const& requests) override;
6265

6366
// Uses `getScoresWithTimes` internally with some wrapping for vector packing/expansion
64-
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTime(LabelScorer::Request const& request) override;
65-
66-
protected:
67-
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;
67+
std::optional<LabelScorer::ScoreWithTime> computeScoreWithTimeInternal(LabelScorer::Request const& request) override;
6868

6969
private:
7070
// Forward a batch of histories through the ONNX model and put the resulting scores into the score cache

0 commit comments

Comments
 (0)