Skip to content

Commit 7706445

Browse files
hannah220curufinwe
andauthored
Add getCurrentBestLatticeTraceback method to SearchAlgorithmV2 and derived classes (#157)
Co-authored-by: Eugen Beck <[email protected]>
1 parent 5d381e7 commit 7706445

File tree

6 files changed

+76
-37
lines changed

6 files changed

+76
-37
lines changed

src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,6 @@ void LexiconfreeTimesyncBeamSearch::reset() {
205205
initializationTime_.stop();
206206
}
207207

208-
Core::Ref<LatticeTrace> LexiconfreeTimesyncBeamSearch::getRootTrace() const {
209-
return rootTrace_;
210-
}
211-
212208
void LexiconfreeTimesyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) {
213209
initializationTime_.start();
214210
labelScorer_->reset();
@@ -239,10 +235,18 @@ void LexiconfreeTimesyncBeamSearch::putFeatures(Nn::DataView const& features, si
239235
featureProcessingTime_.stop();
240236
}
241237

238+
Core::Ref<LatticeTrace> LexiconfreeTimesyncBeamSearch::getRootTrace() const {
239+
return rootTrace_;
240+
}
241+
242242
Core::Ref<const Traceback> LexiconfreeTimesyncBeamSearch::getCurrentBestTraceback() const {
243243
return getBestHypothesis().trace->performTraceback();
244244
}
245245

246+
Core::Ref<const LatticeTraceback> LexiconfreeTimesyncBeamSearch::getCurrentBestLatticeTraceback() const {
247+
return performLatticeTraceback(getBestHypothesis().trace);
248+
}
249+
246250
Core::Ref<const LatticeAdaptor> LexiconfreeTimesyncBeamSearch::getCurrentBestWordLattice() const {
247251
auto& bestHypothesis = getBestHypothesis();
248252
LatticeTrace endTrace(bestHypothesis.trace, 0, bestHypothesis.trace->time + 1, bestHypothesis.trace->score, {});

src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,21 @@ public:
5151

5252
// Inherited methods from `SearchAlgorithmV2`
5353

54-
Speech::ModelCombination::Mode requiredModelCombination() const override;
55-
bool setModelCombination(Speech::ModelCombination const& modelCombination) override;
56-
void reset() override;
57-
Core::Ref<LatticeTrace> getRootTrace() const override;
58-
void enterSegment(Bliss::SpeechSegment const* = nullptr) override;
59-
void finishSegment() override;
60-
void putFeature(Nn::DataView const& feature) override;
61-
void putFeatures(Nn::DataView const& features, size_t nTimesteps) override;
62-
Core::Ref<const Traceback> getCurrentBestTraceback() const override;
63-
Core::Ref<const LatticeAdaptor> getCurrentBestWordLattice() const override;
64-
Core::Ref<LatticeTrace> getCommonPrefix() const override;
65-
bool decodeStep() override;
54+
Speech::ModelCombination::Mode requiredModelCombination() const override;
55+
bool setModelCombination(Speech::ModelCombination const& modelCombination) override;
56+
void reset() override;
57+
void enterSegment(Bliss::SpeechSegment const* = nullptr) override;
58+
void finishSegment() override;
59+
void putFeature(Nn::DataView const& feature) override;
60+
void putFeatures(Nn::DataView const& features, size_t nTimesteps) override;
61+
62+
Core::Ref<LatticeTrace> getRootTrace() const override;
63+
Core::Ref<const Traceback> getCurrentBestTraceback() const override;
64+
Core::Ref<const LatticeTraceback> getCurrentBestLatticeTraceback() const override;
65+
Core::Ref<const LatticeAdaptor> getCurrentBestWordLattice() const override;
66+
Core::Ref<LatticeTrace> getCommonPrefix() const override;
67+
68+
bool decodeStep() override;
6669

6770
protected:
6871
/*

src/Search/SearchV2.hh

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ public:
6969
// Cleans up buffers, hypotheses, flags etc. from the previous segment recognition.
7070
virtual void reset() = 0;
7171

72-
// Return the first trace of all hypotheses. Needed for computing partial trace.
73-
virtual Core::Ref<LatticeTrace> getRootTrace() const = 0;
74-
7572
// Signal the beginning of a new audio segment.
7673
virtual void enterSegment(Bliss::SpeechSegment const* = nullptr) = 0;
7774

@@ -84,9 +81,15 @@ public:
8481
// Pass feature vectors for multiple time steps.
8582
virtual void putFeatures(Nn::DataView const& features, size_t nTimesteps) = 0;
8683

87-
// Return the current best traceback. May contain unstable results.
84+
// Return the first trace of all hypotheses. Needed for computing partial trace.
85+
virtual Core::Ref<LatticeTrace> getRootTrace() const = 0;
86+
87+
// Return the current best traceback of TracebackItem. May contain unstable results.
8888
virtual Core::Ref<const Traceback> getCurrentBestTraceback() const = 0;
8989

90+
// Return the current best traceback of Ref LatticeTrace. May contain unstable results.
91+
virtual Core::Ref<const LatticeTraceback> getCurrentBestLatticeTraceback() const = 0;
92+
9093
// Similar to `getCurrentBestTraceback` but return the lattice instead of just single-best traceback.
9194
virtual Core::Ref<const LatticeAdaptor> getCurrentBestWordLattice() const = 0;
9295

src/Search/Traceback.hh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,28 @@ public:
158158
u32 wordCount() const;
159159
};
160160

161+
/*
162+
* Vector of Refs to LatticeTrace
163+
*/
164+
class LatticeTraceback : public Core::ReferenceCounted, public std::vector<Core::Ref<LatticeTrace>> {
165+
};
166+
167+
/*
168+
* Perform traceback on the given LatticeTrace reference. Returns a vector reference containing the best path
169+
* ending in the given trace.
170+
*/
171+
inline Core::Ref<const LatticeTraceback> performLatticeTraceback(Core::Ref<LatticeTrace> trace) {
172+
LatticeTraceback* traceback = new LatticeTraceback();
173+
174+
while (trace) {
175+
traceback->push_back(trace);
176+
trace = trace->predecessor;
177+
}
178+
179+
std::reverse(traceback->begin(), traceback->end());
180+
return Core::Ref<const LatticeTraceback>(traceback);
181+
}
182+
161183
} // namespace Search
162184

163185
#endif // TRACEBACK_HH

src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,6 @@ void TreeTimesyncBeamSearch::reset() {
257257
initializationTime_.stop();
258258
}
259259

260-
Core::Ref<LatticeTrace> TreeTimesyncBeamSearch::getRootTrace() const {
261-
return rootTrace_;
262-
}
263-
264260
void TreeTimesyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) {
265261
initializationTime_.start();
266262
labelScorer_->reset();
@@ -297,10 +293,18 @@ void TreeTimesyncBeamSearch::putFeatures(Nn::DataView const& features, size_t nT
297293
featureProcessingTime_.stop();
298294
}
299295

296+
Core::Ref<LatticeTrace> TreeTimesyncBeamSearch::getRootTrace() const {
297+
return rootTrace_;
298+
}
299+
300300
Core::Ref<const Traceback> TreeTimesyncBeamSearch::getCurrentBestTraceback() const {
301301
return getBestHypothesis().trace->performTraceback();
302302
}
303303

304+
Core::Ref<const LatticeTraceback> TreeTimesyncBeamSearch::getCurrentBestLatticeTraceback() const {
305+
return performLatticeTraceback(getBestHypothesis().trace);
306+
}
307+
304308
Core::Ref<const LatticeAdaptor> TreeTimesyncBeamSearch::getCurrentBestWordLattice() const {
305309
auto& bestHypothesis = getBestHypothesis();
306310
LatticeTrace endTrace(bestHypothesis.trace, 0, bestHypothesis.trace->time + 1, bestHypothesis.trace->score, {});

src/Search/TreeTimesyncBeamSearch/TreeTimesyncBeamSearch.hh

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,22 @@ public:
5656

5757
// Inherited methods from `SearchAlgorithmV2`
5858

59-
Speech::ModelCombination::Mode requiredModelCombination() const override;
60-
Am::AcousticModel::Mode requiredAcousticModel() const override;
61-
bool setModelCombination(Speech::ModelCombination const& modelCombination) override;
62-
void reset() override;
63-
Core::Ref<LatticeTrace> getRootTrace() const override;
64-
void enterSegment(Bliss::SpeechSegment const* = nullptr) override;
65-
void finishSegment() override;
66-
void putFeature(Nn::DataView const& feature) override;
67-
void putFeatures(Nn::DataView const& features, size_t nTimesteps) override;
68-
Core::Ref<const Traceback> getCurrentBestTraceback() const override;
69-
Core::Ref<const LatticeAdaptor> getCurrentBestWordLattice() const override;
70-
Core::Ref<LatticeTrace> getCommonPrefix() const override;
71-
bool decodeStep() override;
59+
Speech::ModelCombination::Mode requiredModelCombination() const override;
60+
Am::AcousticModel::Mode requiredAcousticModel() const override;
61+
bool setModelCombination(Speech::ModelCombination const& modelCombination) override;
62+
void reset() override;
63+
void enterSegment(Bliss::SpeechSegment const* = nullptr) override;
64+
void finishSegment() override;
65+
void putFeature(Nn::DataView const& feature) override;
66+
void putFeatures(Nn::DataView const& features, size_t nTimesteps) override;
67+
68+
Core::Ref<LatticeTrace> getRootTrace() const override;
69+
Core::Ref<const Traceback> getCurrentBestTraceback() const override;
70+
Core::Ref<const LatticeTraceback> getCurrentBestLatticeTraceback() const override;
71+
Core::Ref<const LatticeAdaptor> getCurrentBestWordLattice() const override;
72+
Core::Ref<LatticeTrace> getCommonPrefix() const override;
73+
74+
bool decodeStep() override;
7275

7376
protected:
7477
/*

0 commit comments

Comments
 (0)