Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/Nn/LabelScorer/ScoringContext.hh
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,23 @@ struct OnnxHiddenState : public Core::ReferenceCounted {
}
};

typedef Core::Ref<OnnxHiddenState> OnnxHiddenStateRef;
typedef Core::Ref<const OnnxHiddenState> OnnxHiddenStateRef;

/*
* Scoring context consisting of a hidden state.
* Assumes that two hidden states are equal if and only if they were created
* from the same label history.
*/
struct OnnxHiddenStateScoringContext : public ScoringContext {
std::vector<LabelIndex> labelSeq; // Used for hashing
OnnxHiddenStateRef hiddenState;
std::vector<LabelIndex> labelSeq; // Used for hashing
mutable OnnxHiddenStateRef hiddenState;
mutable bool requiresFinalize;

OnnxHiddenStateScoringContext()
: labelSeq(), hiddenState() {}
: labelSeq(), hiddenState(), requiresFinalize(false) {}

OnnxHiddenStateScoringContext(std::vector<LabelIndex> const& labelSeq, OnnxHiddenStateRef state)
: labelSeq(labelSeq), hiddenState(state) {}
: labelSeq(labelSeq), hiddenState(state), requiresFinalize(false) {}

bool isEqual(ScoringContextRef const& other) const;
size_t hash() const;
Expand Down
92 changes: 54 additions & 38 deletions src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const Core::ParameterInt StatefulOnnxLabelScorer::paramMaxBatchSize(

const Core::ParameterInt StatefulOnnxLabelScorer::paramMaxCachedScores(
"max-cached-score-vectors",
"Maximum size of cache that maps histories to scores. This prevents memory overflow in case of very long audio segments.",
"Maximum size of cache that maps scoring contexts to scores. This prevents memory overflow in case of very long audio segments.",
1000);

// Scorer only takes hidden states as input which are not part of the IO spec
Expand Down Expand Up @@ -201,7 +201,7 @@ Core::Ref<const ScoringContext> StatefulOnnxLabelScorer::getInitialScoringContex
}

Core::Ref<const ScoringContext> StatefulOnnxLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
OnnxHiddenStateScoringContextRef history(dynamic_cast<const OnnxHiddenStateScoringContext*>(request.context.get()));
OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast<const OnnxHiddenStateScoringContext*>(request.context.get()));

bool updateState = false;
switch (request.transitionType) {
Expand All @@ -224,23 +224,19 @@ Core::Ref<const ScoringContext> StatefulOnnxLabelScorer::extendedScoringContext(
error() << "Unknown transition type " << request.transitionType;
}

// If history is not going to be modified, return the original one
// If scoring context is not going to be modified, return the original one
if (not updateState) {
return request.context;
}

std::vector<LabelIndex> newLabelSeq(history->labelSeq);
std::vector<LabelIndex> newLabelSeq(scoringContext->labelSeq);
newLabelSeq.push_back(request.nextToken);

OnnxHiddenStateRef newHiddenState;
if (not history->hiddenState) { // Sentinel start-state
newHiddenState = updatedHiddenState(computeInitialHiddenState(), request.nextToken);
}
else {
newHiddenState = updatedHiddenState(history->hiddenState, request.nextToken);
}
// Re-use previous hidden-state but mark that finalization (i.e. hidden-state update) is required
auto newScoringContext = Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), scoringContext->hiddenState));
newScoringContext->requiresFinalize = true;

return Core::ref(new OnnxHiddenStateScoringContext(std::move(newLabelSeq), newHiddenState));
return newScoringContext;
}

void StatefulOnnxLabelScorer::addInput(DataView const& input) {
Expand All @@ -264,41 +260,44 @@ std::optional<LabelScorer::ScoresWithTimes> StatefulOnnxLabelScorer::computeScor
result.scores.reserve(requests.size());

/*
* Identify unique histories that still need session runs
* Identify unique scoring contexts that still need session runs
*/
std::unordered_set<OnnxHiddenStateScoringContextRef, ScoringContextHash, ScoringContextEq> uniqueUncachedHistories;
std::unordered_set<OnnxHiddenStateScoringContextRef, ScoringContextHash, ScoringContextEq> uniqueUncachedScoringContexts;

for (auto& request : requests) {
OnnxHiddenStateScoringContextRef historyPtr(dynamic_cast<const OnnxHiddenStateScoringContext*>(request.context.get()));
if (not scoreCache_.contains(historyPtr)) {
// Group by unique history
uniqueUncachedHistories.emplace(historyPtr);
// We need to finalize all scoring contexts before using them for scoring again.

OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast<const OnnxHiddenStateScoringContext*>(request.context.get()));
finalizeScoringContext(scoringContext);
if (not scoreCache_.contains(scoringContext)) {
// Group by unique scoring context
uniqueUncachedScoringContexts.emplace(scoringContext);
}
}

std::vector<OnnxHiddenStateScoringContextRef> historyBatch;
historyBatch.reserve(std::min(uniqueUncachedHistories.size(), maxBatchSize_));
for (auto history : uniqueUncachedHistories) {
historyBatch.push_back(history);
if (historyBatch.size() == maxBatchSize_) { // Batch is full -> forward now
forwardBatch(historyBatch);
historyBatch.clear();
std::vector<OnnxHiddenStateScoringContextRef> scoringContextBatch;
scoringContextBatch.reserve(std::min(uniqueUncachedScoringContexts.size(), maxBatchSize_));
for (auto scoringContext : uniqueUncachedScoringContexts) {
scoringContextBatch.push_back(scoringContext);
if (scoringContextBatch.size() == maxBatchSize_) { // Batch is full -> forward now
forwardBatch(scoringContextBatch);
scoringContextBatch.clear();
}
}

forwardBatch(historyBatch); // Forward remaining histories
forwardBatch(scoringContextBatch); // Forward remaining scoring contexts

/*
* Assign from cache map to result vector
*/
for (const auto& request : requests) {
OnnxHiddenStateScoringContextRef history(dynamic_cast<const OnnxHiddenStateScoringContext*>(request.context.get()));
OnnxHiddenStateScoringContextRef scoringContext(dynamic_cast<const OnnxHiddenStateScoringContext*>(request.context.get()));

verify(scoreCache_.contains(history));
auto const& scores = scoreCache_.get(history)->get();
verify(scoreCache_.contains(scoringContext));
auto const& scores = scoreCache_.get(scoringContext)->get();

result.scores.push_back(scores.at(request.nextToken));
result.timeframes.push_back(history->labelSeq.size());
result.timeframes.push_back(scoringContext->labelSeq.size());
}

return result;
Expand Down Expand Up @@ -425,8 +424,25 @@ OnnxHiddenStateRef StatefulOnnxLabelScorer::updatedHiddenState(OnnxHiddenStateRe
return newHiddenState;
}

void StatefulOnnxLabelScorer::forwardBatch(std::vector<OnnxHiddenStateScoringContextRef> const& historyBatch) {
if (historyBatch.empty()) {
/*
 * Bring the hidden-state of `scoringContext` up to date with its label sequence.
 *
 * Contexts whose `requiresFinalize` flag is set still hold the hidden-state they were
 * constructed with; here that state is forwarded through `updatedHiddenState` together
 * with the last label of `labelSeq`, and the flag is cleared afterwards.
 * Contexts without the flag are left untouched.
 *
 * `hiddenState` and `requiresFinalize` are declared `mutable`, which is why they can be
 * overwritten through the const scoring context reference.
 */
void StatefulOnnxLabelScorer::finalizeScoringContext(OnnxHiddenStateScoringContextRef const& scoringContext) {
    if (scoringContext->requiresFinalize) {
        OnnxHiddenStateRef previousState = scoringContext->hiddenState;

        // An empty Ref is the sentinel start-state; replace it by a real initial hidden-state
        if (not previousState) {
            previousState = computeInitialHiddenState();
        }

        // A pending finalization implies that at least one label has been appended
        verify(not scoringContext->labelSeq.empty());

        scoringContext->hiddenState      = updatedHiddenState(previousState, scoringContext->labelSeq.back());
        scoringContext->requiresFinalize = false;
    }
}

void StatefulOnnxLabelScorer::forwardBatch(std::vector<OnnxHiddenStateScoringContextRef> const& scoringContextBatch) {
if (scoringContextBatch.empty()) {
return;
}

Expand All @@ -439,11 +455,11 @@ void StatefulOnnxLabelScorer::forwardBatch(std::vector<OnnxHiddenStateScoringCon
// Collect a vector of individual state values of shape [1, *] and afterwards concatenate
// them to a batched state tensor of shape [B, *]
std::vector<Onnx::Value const*> stateValues;
stateValues.reserve(historyBatch.size());
stateValues.reserve(scoringContextBatch.size());

for (size_t b = 0ul; b < historyBatch.size(); ++b) {
auto history = historyBatch[b];
auto hiddenState = history->hiddenState;
for (size_t b = 0ul; b < scoringContextBatch.size(); ++b) {
auto scoringContext = scoringContextBatch[b];
auto hiddenState = scoringContext->hiddenState;
if (not hiddenState) { // Sentinel hidden-state
hiddenState = computeInitialHiddenState();
}
Expand All @@ -461,10 +477,10 @@ void StatefulOnnxLabelScorer::forwardBatch(std::vector<OnnxHiddenStateScoringCon
/*
* Put resulting scores into cache map
*/
for (size_t b = 0ul; b < historyBatch.size(); ++b) {
for (size_t b = 0ul; b < scoringContextBatch.size(); ++b) {
std::vector<f32> scoreVec;
sessionOutputs.front().get(b, scoreVec);
scoreCache_.put(historyBatch[b], std::move(scoreVec));
scoreCache_.put(scoringContextBatch[b], std::move(scoreVec));
}
}

Expand Down
14 changes: 9 additions & 5 deletions src/Nn/LabelScorer/StatefulOnnxLabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ public:

void reset() override;

// If startLabelIndex is set, forward that through the state updater to obtain the start history
// If startLabelIndex is set, forward that through the state updater to obtain the start ScoringContext
Core::Ref<const ScoringContext> getInitialScoringContext() override;

// Forward hidden-state through state-updater ONNX model
// Append the new token to the label sequence; does not update the hidden-state. This is only done once the scoringContext is used for scoring again.
Core::Ref<const ScoringContext> extendedScoringContext(LabelScorer::Request const& request) override;

// Add a single encoder output to buffer
Expand All @@ -87,13 +87,17 @@ protected:
size_t getMinActiveInputIndex(Core::CollapsedVector<ScoringContextRef> const& activeContexts) const override;

private:
// Forward a batch of histories through the ONNX model and put the resulting scores into the score cache
void forwardBatch(std::vector<OnnxHiddenStateScoringContextRef> const& historyBatch);
// Forward a batch of scoringContexts through the ONNX model and put the resulting scores into the score cache
void forwardBatch(std::vector<OnnxHiddenStateScoringContextRef> const& scoringContextBatch);

// Computes new hidden state based on previous hidden state and next token through state-updater call
OnnxHiddenStateRef updatedHiddenState(OnnxHiddenStateRef const& hiddenState, LabelIndex nextToken);

// Replace hidden-state in scoringContext with an updated version that includes the last label
void finalizeScoringContext(OnnxHiddenStateScoringContextRef const& scoringContext);

// Since the hidden-state matrix depends on the encoder time axis, we cannot properly create hidden-states until all encoder states have been passed.
// So getStartHistory sets the initial hidden-state to a sentinel value (empty Ref) and when other functions such as `extendedHistory` and `getScoresWithTime`
// So getInitialScoringContext sets the initial hidden-state to a sentinel value (empty Ref) and when other functions such as `extendedScoringContext` and `getScoresWithTime`
// encounter this sentinel value they call `computeInitialHiddenState` instead to get a usable hidden-state.
OnnxHiddenStateRef computeInitialHiddenState();

Expand Down