|
18 | 18 |
|
19 | 19 | #include <Core/ReferenceCounting.hh> |
20 | 20 | #include <Mm/Types.hh> |
| 21 | +#include <Onnx/Value.hh> |
21 | 22 | #include <Speech/Types.hh> |
22 | 23 |
|
23 | 24 | namespace Nn { |
@@ -105,6 +106,47 @@ struct SeqStepScoringContext : public ScoringContext { |
105 | 106 |
|
106 | 107 | typedef Core::Ref<const SeqStepScoringContext> SeqStepScoringContextRef; |
107 | 108 |
|
| 109 | +/* |
| 110 | + * Hidden state represented by a dictionary of named ONNX values |
| 111 | + */ |
| 112 | +struct OnnxHiddenState : public Core::ReferenceCounted { |
| 113 | + std::unordered_map<std::string, Onnx::Value> stateValueMap; |
| 114 | + |
| 115 | + OnnxHiddenState() |
| 116 | + : stateValueMap() {} |
| 117 | + |
| 118 | + OnnxHiddenState(std::vector<std::string>&& names, std::vector<Onnx::Value>&& values) { |
| 119 | + verify(names.size() == values.size()); |
| 120 | + stateValueMap.reserve(names.size()); |
| 121 | + for (size_t i = 0ul; i < names.size(); ++i) { |
| 122 | + stateValueMap.emplace(std::move(names[i]), std::move(values[i])); |
| 123 | + } |
| 124 | + } |
| 125 | +}; |
| 126 | + |
| 127 | +typedef Core::Ref<OnnxHiddenState> OnnxHiddenStateRef; |
| 128 | + |
| 129 | +/* |
| 130 | + * Scoring context consisting of a hidden state. |
| 131 | + * Assumes that two hidden states are equal if and only if they were created |
| 132 | + * from the same label history. |
| 133 | + */ |
| 134 | +struct OnnxHiddenStateScoringContext : public ScoringContext { |
| 135 | + std::vector<LabelIndex> labelSeq; // Used for hashing |
| 136 | + OnnxHiddenStateRef hiddenState; |
| 137 | + |
| 138 | + OnnxHiddenStateScoringContext() |
| 139 | + : labelSeq(), hiddenState() {} |
| 140 | + |
| 141 | + OnnxHiddenStateScoringContext(std::vector<LabelIndex> const& labelSeq, OnnxHiddenStateRef state) |
| 142 | + : labelSeq(labelSeq), hiddenState(state) {} |
| 143 | + |
| 144 | + bool isEqual(ScoringContextRef const& other) const; |
| 145 | + size_t hash() const; |
| 146 | +}; |
| 147 | + |
| 148 | +typedef Core::Ref<const OnnxHiddenStateScoringContext> OnnxHiddenStateScoringContextRef; |
| 149 | + |
108 | 150 | } // namespace Nn |
109 | 151 |
|
110 | 152 | #endif // SCORING_CONTEXT_HH |
0 commit comments