@@ -72,7 +72,7 @@ static const std::vector<Onnx::IOSpecification> ioSpec = {
7272
7373FixedContextOnnxLabelScorer::FixedContextOnnxLabelScorer (Core::Configuration const & config)
7474 : Core::Component(config),
75- Precursor (config),
75+ Precursor (config, TransitionPresetType::TRANSDUCER ),
7676 startLabelIndex_(paramStartLabelIndex(config)),
7777 historyLength_(paramHistoryLength(config)),
7878 blankUpdatesHistory_(paramBlankUpdatesHistory(config)),
@@ -97,7 +97,32 @@ ScoringContextRef FixedContextOnnxLabelScorer::getInitialScoringContext() {
9797 return hist;
9898}
9999
100- ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext (LabelScorer::Request const & request) {
100+ size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex (Core::CollapsedVector<ScoringContextRef> const & activeContexts) const {
101+ auto minTimeIndex = Core::Type<Speech::TimeframeIndex>::max;
102+ for (auto const & context : activeContexts.internalData ()) {
103+ SeqStepScoringContextRef stepHistory (dynamic_cast <const SeqStepScoringContext*>(context.get ()));
104+ minTimeIndex = std::min (minTimeIndex, stepHistory->currentStep );
105+ }
106+
107+ return minTimeIndex;
108+ }
109+
110+ void FixedContextOnnxLabelScorer::cleanupCaches (Core::CollapsedVector<ScoringContextRef> const & activeContexts) {
111+ Precursor::cleanupCaches (activeContexts);
112+
113+ std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> activeContextSet (activeContexts.internalData ().begin (), activeContexts.internalData ().end ());
114+
115+ for (auto it = scoreCache_.begin (); it != scoreCache_.end ();) {
116+ if (activeContextSet.find (it->first ) == activeContextSet.end ()) {
117+ it = scoreCache_.erase (it);
118+ }
119+ else {
120+ ++it;
121+ }
122+ }
123+ }
124+
125+ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContextInternal (LabelScorer::Request const & request) {
101126 SeqStepScoringContextRef context (dynamic_cast <const SeqStepScoringContext*>(request.context .get ()));
102127
103128 bool pushToken = false ;
@@ -144,22 +169,11 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore
144169 return Core::ref (new SeqStepScoringContext (std::move (newLabelSeq), context->currentStep + timeIncrement));
145170}
146171
147- void FixedContextOnnxLabelScorer::cleanupCaches (Core::CollapsedVector<ScoringContextRef> const & activeContexts) {
148- Precursor::cleanupCaches (activeContexts);
149-
150- std::unordered_set<ScoringContextRef, ScoringContextHash, ScoringContextEq> activeContextSet (activeContexts.internalData ().begin (), activeContexts.internalData ().end ());
151-
152- for (auto it = scoreCache_.begin (); it != scoreCache_.end ();) {
153- if (activeContextSet.find (it->first ) == activeContextSet.end ()) {
154- it = scoreCache_.erase (it);
155- }
156- else {
157- ++it;
158- }
172+ std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimesInternal (std::vector<LabelScorer::Request> const & requests) {
173+ if (requests.empty ()) {
174+ return ScoresWithTimes{};
159175 }
160- }
161176
162- std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::computeScoresWithTimes (std::vector<LabelScorer::Request> const & requests) {
163177 ScoresWithTimes result;
164178 result.scores .reserve (requests.size ());
165179
@@ -232,24 +246,14 @@ std::optional<LabelScorer::ScoresWithTimes> FixedContextOnnxLabelScorer::compute
232246 return result;
233247}
234248
235- std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTime (LabelScorer::Request const & request) {
249+ std::optional<LabelScorer::ScoreWithTime> FixedContextOnnxLabelScorer::computeScoreWithTimeInternal (LabelScorer::Request const & request) {
236250 auto result = computeScoresWithTimes ({request});
237251 if (not result.has_value ()) {
238252 return {};
239253 }
240254 return ScoreWithTime{result->scores .front (), result->timeframes .front ()};
241255}
242256
243- size_t FixedContextOnnxLabelScorer::getMinActiveInputIndex (Core::CollapsedVector<ScoringContextRef> const & activeContexts) const {
244- auto minTimeIndex = Core::Type<Speech::TimeframeIndex>::max;
245- for (auto const & context : activeContexts.internalData ()) {
246- SeqStepScoringContextRef stepHistory (dynamic_cast <const SeqStepScoringContext*>(context.get ()));
247- minTimeIndex = std::min (minTimeIndex, stepHistory->currentStep );
248- }
249-
250- return minTimeIndex;
251- }
252-
253257void FixedContextOnnxLabelScorer::forwardBatch (std::vector<SeqStepScoringContextRef> const & contextBatch) {
254258 if (contextBatch.empty ()) {
255259 return ;
0 commit comments