Skip to content

Commit e1bbe9e

Browse files
authored
Add Nn::StatefulOnnxLabelScorer (#140)
1 parent a4914f0 commit e1bbe9e

File tree

10 files changed

+896
-1
lines changed

10 files changed

+896
-1
lines changed

src/Nn/LabelScorer/BufferedLabelScorer.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,8 @@ std::optional<DataView> BufferedLabelScorer::getInput(size_t inputIndex) const {
6565
return inputBuffer_[bufferPosition];
6666
}
6767

68+
size_t BufferedLabelScorer::bufferSize() const {
69+
return inputBuffer_.size();
70+
}
71+
6872
} // namespace Nn

src/Nn/LabelScorer/BufferedLabelScorer.hh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ protected:
6363
// - getInput(3) will return None since no fourth input was added yet
6464
std::optional<DataView> getInput(size_t inputIndex) const;
6565

66+
// Get number of currently buffered elements
67+
size_t bufferSize() const;
68+
6669
private:
6770
std::deque<DataView> inputBuffer_; // Buffer that contains all the feature data for the current segment
6871
size_t numDeletedInputs_; // Count deleted inputs in order to address the correct index in inputBuffer_

src/Nn/LabelScorer/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ LIBSPRINTLABELSCORER_O = \
2121
$(OBJDIR)/FixedContextOnnxLabelScorer.o \
2222
$(OBJDIR)/NoContextOnnxLabelScorer.o \
2323
$(OBJDIR)/NoOpLabelScorer.o \
24-
$(OBJDIR)/ScoringContext.o
24+
$(OBJDIR)/ScoringContext.o \
25+
$(OBJDIR)/StatefulOnnxLabelScorer.o
2526

2627
# -----------------------------------------------------------------------------
2728

src/Nn/LabelScorer/ScoringContext.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,32 @@ bool SeqStepScoringContext::isEqual(ScoringContextRef const& other) const {
108108
return true;
109109
}
110110

111+
/*
112+
* =================================
113+
* = OnnxHiddenStateScoringContext =
114+
* =================================
115+
*/
116+
size_t OnnxHiddenStateScoringContext::hash() const {
117+
return Core::MurmurHash3_x64_64(reinterpret_cast<void const*>(labelSeq.data()), labelSeq.size() * sizeof(LabelIndex), 0x78b174eb);
118+
}
119+
120+
bool OnnxHiddenStateScoringContext::isEqual(ScoringContextRef const& other) const {
121+
auto* otherPtr = dynamic_cast<const OnnxHiddenStateScoringContext*>(other.get());
122+
if (otherPtr == nullptr) {
123+
return false;
124+
}
125+
126+
if (labelSeq.size() != otherPtr->labelSeq.size()) {
127+
return false;
128+
}
129+
130+
for (auto it_l = labelSeq.begin(), it_r = otherPtr->labelSeq.begin(); it_l != labelSeq.end(); ++it_l, ++it_r) {
131+
if (*it_l != *it_r) {
132+
return false;
133+
}
134+
}
135+
136+
return true;
137+
}
138+
111139
} // namespace Nn

src/Nn/LabelScorer/ScoringContext.hh

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <Core/ReferenceCounting.hh>
2020
#include <Mm/Types.hh>
21+
#include <Onnx/Value.hh>
2122
#include <Speech/Types.hh>
2223

2324
namespace Nn {
@@ -105,6 +106,47 @@ struct SeqStepScoringContext : public ScoringContext {
105106

106107
typedef Core::Ref<const SeqStepScoringContext> SeqStepScoringContextRef;
107108

109+
/*
110+
* Hidden state represented by a dictionary of named ONNX values
111+
*/
112+
struct OnnxHiddenState : public Core::ReferenceCounted {
113+
std::unordered_map<std::string, Onnx::Value> stateValueMap;
114+
115+
OnnxHiddenState()
116+
: stateValueMap() {}
117+
118+
OnnxHiddenState(std::vector<std::string>&& names, std::vector<Onnx::Value>&& values) {
119+
verify(names.size() == values.size());
120+
stateValueMap.reserve(names.size());
121+
for (size_t i = 0ul; i < names.size(); ++i) {
122+
stateValueMap.emplace(std::move(names[i]), std::move(values[i]));
123+
}
124+
}
125+
};
126+
127+
typedef Core::Ref<OnnxHiddenState> OnnxHiddenStateRef;
128+
129+
/*
130+
* Scoring context consisting of a hidden state.
131+
* Assumes that two hidden states are equal if and only if they were created
132+
* from the same label history.
133+
*/
134+
struct OnnxHiddenStateScoringContext : public ScoringContext {
135+
std::vector<LabelIndex> labelSeq; // Used for hashing
136+
OnnxHiddenStateRef hiddenState;
137+
138+
OnnxHiddenStateScoringContext()
139+
: labelSeq(), hiddenState() {}
140+
141+
OnnxHiddenStateScoringContext(std::vector<LabelIndex> const& labelSeq, OnnxHiddenStateRef state)
142+
: labelSeq(labelSeq), hiddenState(state) {}
143+
144+
bool isEqual(ScoringContextRef const& other) const;
145+
size_t hash() const;
146+
};
147+
148+
typedef Core::Ref<const OnnxHiddenStateScoringContext> OnnxHiddenStateScoringContextRef;
149+
108150
} // namespace Nn
109151

110152
#endif // SCORING_CONTEXT_HH

0 commit comments

Comments
 (0)