diff --git a/src/Am/AcousticModel.hh b/src/Am/AcousticModel.hh
index 1929ed24c..57803e435 100644
--- a/src/Am/AcousticModel.hh
+++ b/src/Am/AcousticModel.hh
@@ -65,7 +65,8 @@ public:
 public:
     AcousticModel(const Core::Configuration& c)
-            : Core::Component(c), Mc::Component(c) {}
+            : Core::Component(c),
+              Mc::Component(c) {}
     virtual ~AcousticModel() {}
     virtual void load(Mode mode = complete) = 0;
diff --git a/src/Am/AcousticModelAdaptor.cc b/src/Am/AcousticModelAdaptor.cc
index d7ad9e4df..41b6c2886 100644
--- a/src/Am/AcousticModelAdaptor.cc
+++ b/src/Am/AcousticModelAdaptor.cc
@@ -17,7 +17,8 @@ using namespace Am;
 MixtureSetAdaptor::MixtureSetAdaptor(const Core::Configuration& c, Core::Ref toAdapt)
-        : Component(c), Precursor(c, toAdapt) {
+        : Component(c),
+          Precursor(c, toAdapt) {
     Core::Ref mixtureSet = Mm::Module::instance().readMixtureSet(select("mixture-set"));
     if (!mixtureSet || !setMixtureSet(mixtureSet))
         criticalError("failed to initialize mixture set.");
diff --git a/src/Am/AcousticModelAdaptor.hh b/src/Am/AcousticModelAdaptor.hh
index 2c268f501..af4ffe55f 100644
--- a/src/Am/AcousticModelAdaptor.hh
+++ b/src/Am/AcousticModelAdaptor.hh
@@ -29,7 +29,8 @@ protected:
 public:
     AcousticModelAdaptor(const Core::Configuration& c, Core::Ref toAdapt)
-            : Component(c), toAdapt_(toAdapt) {
+            : Component(c),
+              toAdapt_(toAdapt) {
         require(toAdapt_);
     }
     virtual ~AcousticModelAdaptor() {}
diff --git a/src/Am/ClassicAcousticModel.cc b/src/Am/ClassicAcousticModel.cc
index 5963a12f4..11ca85829 100644
--- a/src/Am/ClassicAcousticModel.cc
+++ b/src/Am/ClassicAcousticModel.cc
@@ -168,6 +168,7 @@ bool ClassicAcousticModel::loadTransitionModel() {
         transitionModel_->setParentScale(scale());
         return true;
     }
+    return false;
 }
diff --git a/src/Am/ClassicAcousticModel.hh b/src/Am/ClassicAcousticModel.hh
index 3e35b0f74..5cdf8cc02 100644
--- a/src/Am/ClassicAcousticModel.hh
+++ b/src/Am/ClassicAcousticModel.hh
@@ -136,16 +136,20 @@ public:
         return (*transitionModel_)[i];
     }
     virtual StateTransitionIndex stateTransitionIndex(AllophoneState phone, s8 subState = 0) const {
-        if (allophoneStateAlphabet()->isSilence(phone))
+        if (allophoneStateAlphabet()->isSilence(phone)) {
             return TransitionModel::silence;
-        else
+        }
+        else {
             return transitionModel_->classify(phone, subState);
+        }
     }
     virtual StateTransitionIndex stateTransitionIndex(AllophoneStateIndex e, s8 subState = 0) const {
-        if (silenceAllophoneStateIndex_ != Fsa::InvalidLabelId and e == silenceAllophoneStateIndex_)
+        if (silenceAllophoneStateIndex_ != Fsa::InvalidLabelId and e == silenceAllophoneStateIndex_) {
             return TransitionModel::silence;
-        else
+        }
+        else {
             return transitionModel_->classifyIndex(e, subState);
+        }
     }
     virtual const ClassicHmmTopology* hmmTopology(Bliss::Phoneme::Id phoneme) const {
diff --git a/src/Am/ClassicHmmTopologySet.hh b/src/Am/ClassicHmmTopologySet.hh
index 6177045cb..d9725a23c 100644
--- a/src/Am/ClassicHmmTopologySet.hh
+++ b/src/Am/ClassicHmmTopologySet.hh
@@ -30,7 +30,8 @@ class ClassicHmmTopology {
 public:
     ClassicHmmTopology(int nPhoneStates, int nSubStates)
-            : nPhoneStates_(nPhoneStates), nSubStates_(nSubStates) {}
+            : nPhoneStates_(nPhoneStates),
+              nSubStates_(nSubStates) {}
     int nPhoneStates() const {
         return nPhoneStates_;
     }
diff --git a/src/Am/ClassicStateModel.cc b/src/Am/ClassicStateModel.cc
index 12df00935..17bb0d303 100644
--- a/src/Am/ClassicStateModel.cc
+++ b/src/Am/ClassicStateModel.cc
@@ -531,8 +531,9 @@ std::string AllophoneAlphabet::symbol(Fsa::LabelId id) const {
 Fsa::LabelId AllophoneAlphabet::index(const
std::string& symbol) const { Fsa::LabelId lid = specialIndex(symbol); - if (lid == Fsa::InvalidLabelId) + if (lid == Fsa::InvalidLabelId) { lid = index(fromString(symbol)); + } return lid; } @@ -798,7 +799,8 @@ struct IndexedAllophoneState { Am::AllophoneState allophoneState; IndexedAllophoneState() {} IndexedAllophoneState(Fsa::LabelId id, Am::AllophoneState allophoneState) - : id(id), allophoneState(allophoneState) {} + : id(id), + allophoneState(allophoneState) {} bool operator<(const IndexedAllophoneState& i) const { return id < i.id; } diff --git a/src/Am/ClassicStateModel.hh b/src/Am/ClassicStateModel.hh index 417b1ebb6..f10fd2163 100644 --- a/src/Am/ClassicStateModel.hh +++ b/src/Am/ClassicStateModel.hh @@ -56,9 +56,11 @@ struct Allophone : public Phonology::Allophone { Allophone() : Precursor() {} Allophone(Bliss::Phoneme::Id phoneme, s16 b) - : Precursor(phoneme), boundary(b) {} + : Precursor(phoneme), + boundary(b) {} Allophone(const Phonology::Allophone& a, s16 b) - : Precursor(a), boundary(b) {} + : Precursor(a), + boundary(b) {} bool operator==(const Allophone& allo) const { return (boundary == allo.boundary) && Precursor::operator==(allo); @@ -66,7 +68,8 @@ struct Allophone : public Phonology::Allophone { struct Hash { Phonology::Allophone::Hash ah; - u32 operator()(const Allophone& a) const { + + u32 operator()(const Allophone& a) const { return ah(a) ^ (u32(a.boundary) << 13); } }; @@ -103,7 +106,8 @@ public: private: struct AllophonePtrHash { Allophone::Hash h; - size_t operator()(const Allophone* const a) const { + + size_t operator()(const Allophone* const a) const { return h(*a); } }; @@ -228,11 +232,13 @@ private: private: AllophoneState(const Allophone* allo, s16 state) - : allo_(allo), state_(state) {} + : allo_(allo), + state_(state) {} public: AllophoneState() - : allo_(0), state_(0) {} + : allo_(0), + state_(0) {} operator const Allophone*() const { require(allo_); @@ -258,13 +264,15 @@ public: require(allo_); return allo_; } + s16 state() const { return state_; } struct Hash { Allophone::Hash ah; - u32 operator()(const AllophoneState& alloState) const { + + u32 operator()(const AllophoneState& alloState) const { return ah(alloState) ^ (u32(alloState.state()) << 21); } }; diff --git a/src/Am/ClassicStateTying.hh b/src/Am/ClassicStateTying.hh index ad16cfea1..c1f8cf21f 100644 --- a/src/Am/ClassicStateTying.hh +++ b/src/Am/ClassicStateTying.hh @@ -38,7 +38,8 @@ private: public: EmissionAlphabet(Mm::MixtureIndex nMixtures = 0) - : nMixtures_(nMixtures), nDisambiguators_(0) {} + : nMixtures_(nMixtures), + nDisambiguators_(0) {} Mm::MixtureIndex nMixtures() const { return nMixtures_; } diff --git a/src/Am/ClassicTransducerBuilder.cc b/src/Am/ClassicTransducerBuilder.cc index 821ea64b2..7431a385b 100644 --- a/src/Am/ClassicTransducerBuilder.cc +++ b/src/Am/ClassicTransducerBuilder.cc @@ -31,7 +31,8 @@ class ClassicTransducerBuilder::Statistics { bool isCoarticulated; PhoneBoundaryStateSetDescriptor( const PhoneBoundaryStateDescriptor& pbsd) - : phoneBoundaryFlag(pbsd.flag), isCoarticulated(pbsd.isCoarticulated()) { + : phoneBoundaryFlag(pbsd.flag), + isCoarticulated(pbsd.isCoarticulated()) { } Core::XmlOpen xmlAttributes(const Core::XmlOpen&) const; @@ -140,7 +141,10 @@ static const Core::ParameterBool paramFixAllophoneContextAtWordBoundaries( ClassicTransducerBuilder::ClassicTransducerBuilder(Core::Ref< const ClassicAcousticModel> model) - : TransducerBuilder(), model_(model), silencesAndNoises_(0), allophoneSuffixes_(2500, AllophoneSuffix::Hash(this), 
AllophoneSuffix::Equality(this)) { + : TransducerBuilder(), + model_(model), + silencesAndNoises_(0), + allophoneSuffixes_(2500, AllophoneSuffix::Hash(this), AllophoneSuffix::Equality(this)) { require(model); allophones_ = model_->allophoneAlphabet(); allophoneList_ = &model_->allophoneAlphabet()->allophones(); @@ -649,8 +653,7 @@ Fsa::ConstAutomatonRef ClassicTransducerBuilder::applyTransitionModel( const Allophone* silenceAllophone = allophones_->allophone(Allophone( Phonology::Allophone(model_->silence_), Allophone::isInitialPhone | Allophone::isFinalPhone)); - AllophoneState silenceState = allophoneStates_->allophoneState( - silenceAllophone, 0); + AllophoneState silenceState = allophoneStates_->allophoneState(silenceAllophone, 0); Fsa::LabelId silenceLabel = Fsa::InvalidLabelId; if (ff->inputAlphabet() == allophoneStates_) @@ -921,15 +924,21 @@ struct PhoneContext : public Phones { bool boundary_; Fsa::LabelId disambiguator_; PhoneContext(size_t n, bool boundary = false, Fsa::LabelId disambiguator = Fsa::Epsilon) - : Phones(n, Bliss::Phoneme::term), boundary_(boundary), disambiguator_(disambiguator) { + : Phones(n, Bliss::Phoneme::term), + boundary_(boundary), + disambiguator_(disambiguator) { } PhoneContext(const Phones& phones, bool boundary = false, Fsa::LabelId disambiguator = Fsa::Epsilon) - : Phones(phones), boundary_(boundary), disambiguator_(disambiguator) { + : Phones(phones), + boundary_(boundary), + disambiguator_(disambiguator) { } PhoneContext(Phones::const_iterator pBegin, Phones::const_iterator pEnd, bool boundary = false, Fsa::LabelId disambiguator = Fsa::Epsilon) - : Phones(pBegin, pEnd), boundary_(boundary), disambiguator_(disambiguator) { + : Phones(pBegin, pEnd), + boundary_(boundary), + disambiguator_(disambiguator) { } }; @@ -949,7 +958,9 @@ struct NoCoartBoundary { Fsa::LabelId input_, phoneDisambiguator_; NoCoartBoundary(Fsa::State* state, Fsa::LabelId input, Fsa::LabelId phoneDisambiguator) - : state_(state), input_(input), phoneDisambiguator_(phoneDisambiguator) { + : state_(state), + input_(input), + phoneDisambiguator_(phoneDisambiguator) { } }; @@ -1005,10 +1016,10 @@ Fsa::ConstAutomatonRef ClassicTransducerBuilder::createMinimizedContextDependenc Fsa::Hash stateMap; PhoneContext pc(maxHistory + maxFuture, true); - Fsa::State* initial = new Fsa::State(stateMap.insert(pc), - model_->isAcrossWordModelEnabled() ? Fsa::StateTagFinal - : Fsa::StateTagNone, - semiring->one()); + + Fsa::State* initial = new Fsa::State(stateMap.insert(pc), + model_->isAcrossWordModelEnabled() ? 
Fsa::StateTagFinal : Fsa::StateTagNone, + semiring->one()); _c->setState(initial); _c->setInitialStateId(initial->id()); if (model_->isAcrossWordModelEnabled()) { diff --git a/src/Am/ClassicTransducerBuilder.hh b/src/Am/ClassicTransducerBuilder.hh index c8e4a4af7..5a9898715 100644 --- a/src/Am/ClassicTransducerBuilder.hh +++ b/src/Am/ClassicTransducerBuilder.hh @@ -76,21 +76,27 @@ private: // internal struct PhoneBoundaryStateDescriptor { Phonology::Context context; u8 flag; // PhoneBoundaryFlags - bool isWordStart() const { + + bool isWordStart() const { return (flag & wordStart); } + bool isWordEnd() const { return (flag & wordEnd); } + bool isCoarticulated() const { return (!context.history.empty() && !context.future.empty()); } + bool operator==(const PhoneBoundaryStateDescriptor& r) const { return (context == r.context) && (flag == r.flag); } + struct Hash { Phonology::Context::Hash ch; - u32 operator()(const PhoneBoundaryStateDescriptor& pbsd) const { + + u32 operator()(const PhoneBoundaryStateDescriptor& pbsd) const { return ch(pbsd.context) ^ (u32(pbsd.flag) << 13); } }; @@ -114,8 +120,10 @@ private: // internal struct Hash { ClassicTransducerBuilder* model; + Hash(ClassicTransducerBuilder* mm) : model(mm) {} + size_t operator()(const AllophoneSuffix& as) const { if (!as.hash_) { as.hash_ = model->hashSequence(as); @@ -128,20 +136,21 @@ private: // internal struct Equality { ClassicTransducerBuilder* model; + Equality(ClassicTransducerBuilder* mm) : model(mm) {} + bool operator()(const AllophoneSuffix& l, const AllophoneSuffix& r) const { return model->compareSequences(l, r) == 0; } }; }; + size_t hashSequence(const AllophoneSuffix&); int compareSequences(const AllophoneSuffix&, const AllophoneSuffix&); - typedef std::unordered_map< - AllophoneSuffix, - Fsa::StateId, - AllophoneSuffix::Hash, AllophoneSuffix::Equality> - AllophoneSuffixMap; + + typedef std::unordered_map AllophoneSuffixMap; + AllophoneSuffixMap allophoneSuffixes_; class Statistics; diff --git a/src/Am/StateModel.hh b/src/Am/StateModel.hh index 74dd0de61..24a692488 100644 --- a/src/Am/StateModel.hh +++ b/src/Am/StateModel.hh @@ -34,18 +34,23 @@ private: public: EmissionAlphabet(Mm::MixtureIndex nMixtures = 0) - : nMixtures_(nMixtures), nDisambiguators_(0) {} + : nMixtures_(nMixtures), + nDisambiguators_(0) {} + Mm::MixtureIndex nMixtures() const { return nMixtures_; } + u32 nDisambiguators() const { return nDisambiguators_; } + Fsa::LabelId disambiguator(u32 d) const { if (nDisambiguators_ >= d) nDisambiguators_ = d + 1; return nMixtures_ + d; } + virtual bool isDisambiguator(Fsa::LabelId m) const { return m >= Fsa::LabelId(nMixtures_); } @@ -53,6 +58,7 @@ public: virtual const_iterator end() const { return const_iterator(Fsa::ConstAlphabetRef(this), nMixtures_ + nDisambiguators_); } + virtual std::string symbol(Fsa::LabelId) const; virtual void writeXml(Core::XmlWriter&) const; }; @@ -66,10 +72,13 @@ struct Allophone : public Phonology::Allophone { static const u8 isFinalPhone = 2; Allophone() {} Allophone(const Phonology::Allophone& a, s16 b) - : Phonology::Allophone(a), boundary(b) {} + : Phonology::Allophone(a), + boundary(b) {} + struct Hash { Phonology::Allophone::Hash ah; - u32 operator()(const Allophone& a) const { + + u32 operator()(const Allophone& a) const { return ah(a) ^ (u32(a.boundary) << 13); } }; @@ -80,16 +89,20 @@ struct AllophoneState : public Allophone { s16 state; AllophoneState() {} AllophoneState(const Allophone& a, s16 s) - : Precursor(a), state(s) {} + : Precursor(a), + state(s) {} 
AllophoneState(const AllophoneState& as) - : Precursor(as), state(as.state) {} + : Precursor(as), + state(as.state) {} bool operator==(const AllophoneState& rhs) const { return Precursor::operator==(rhs) && state == rhs.state; } + struct Hash { Precursor::Hash ah; - u32 operator()(const AllophoneState& a) const { + + u32 operator()(const AllophoneState& a) const { return ah(a) ^ (u32(a.state) << 21); } }; @@ -113,23 +126,30 @@ public: AllophoneStateAlphabet(); explicit AllophoneStateAlphabet(Core::Ref, u32 contextLength = 0, u32 nStates = 0); - void set(Core::Ref, u32 contextLength, u32 nStates); + + void set(Core::Ref, u32 contextLength, u32 nStates); + Core::Ref phonemeInventory() const { return pi_; } + Fsa::LabelId nClasses() const { return nClasses_; } + Fsa::LabelId index(const AllophoneState&) const; AllophoneState allophoneState(Fsa::LabelId) const; - u32 nDisambiguators() const { + + u32 nDisambiguators() const { return nDisambiguators_; } + Fsa::LabelId disambiguator(u32 d) const { if (nDisambiguators_ >= d) nDisambiguators_ = d + 1; return nClasses_ + 1 + d; } + virtual bool isDisambiguator(Fsa::LabelId m) const { return m > Fsa::LabelId(nClasses_); } @@ -137,9 +157,11 @@ public: virtual const_iterator end() const { return const_iterator(Fsa::ConstAlphabetRef(this), nClasses_ + 1 + nDisambiguators_); } + virtual Fsa::LabelId next(Fsa::LabelId id) const { return ++id; } + virtual std::string symbol(Fsa::LabelId) const; virtual void writeXml(Core::XmlWriter& os) const; }; @@ -148,9 +170,11 @@ class EmissionToPhonemeTransducer : public Fsa::StaticAutomaton { public: EmissionToPhonemeTransducer() {} EmissionToPhonemeTransducer(u32 nMixtures, Core::Ref); + const EmissionAlphabet* emissionAlphabet() const { return dynamic_cast(getInputAlphabet().get()); } + const Bliss::PhonemeAlphabet* phonemeAlphabet() const { return dynamic_cast(getOutputAlphabet().get()); } @@ -160,9 +184,11 @@ class AllophoneStateToPhonemeTransducer : public Fsa::StaticAutomaton { public: AllophoneStateToPhonemeTransducer() {} AllophoneStateToPhonemeTransducer(Core::Ref); + const AllophoneStateAlphabet* allophoneStateAlphabet() const { return dynamic_cast(getInputAlphabet().get()); } + const Bliss::PhonemeAlphabet* phonemeAlphabet() const { return dynamic_cast(getOutputAlphabet().get()); } @@ -178,14 +204,19 @@ protected: public: StateTying(const Core::Configuration& c, const AllophoneStateAlphabet& a) - : Component(c), alphabet_(a) {} + : Component(c), + alphabet_(a) {} virtual ~StateTying() {} - virtual void getDependencies(Core::DependencySet&) const {} + + virtual void getDependencies(Core::DependencySet&) const {} + const AllophoneStateAlphabet& allophoneStateAlphabet() const { return alphabet_; } + virtual Mm::MixtureIndex nClasses() const = 0; virtual Mm::MixtureIndex classify(const AllophoneState& as) const = 0; + virtual Mm::MixtureIndex classifyIndex(AllophoneStateAlphabet::Index index) const { CacheMap::const_iterator iter = classifyIndexCache_.find(index); if (iter == classifyIndexCache_.end()) { @@ -202,14 +233,17 @@ public: class NoStateTying : public StateTying { public: NoStateTying(const Core::Configuration& c, const AllophoneStateAlphabet& a) - : Component(c), StateTying(c, a) {} + : Component(c), + StateTying(c, a) {} virtual Mm::MixtureIndex nClasses() const { return alphabet_.nClasses(); } + virtual Mm::MixtureIndex classifyIndex(AllophoneStateAlphabet::Index i) const { return i; } + virtual Mm::MixtureIndex classify(const AllophoneState& as) const { return alphabet_.index(as); } diff --git 
a/src/Am/TransitionModel.cc b/src/Am/TransitionModel.cc index b079d4300..70c372a85 100644 --- a/src/Am/TransitionModel.cc +++ b/src/Am/TransitionModel.cc @@ -139,7 +139,10 @@ struct ApplicatorState { } ApplicatorState(Mask m, Fsa::LabelId e, StateType t, Fsa::StateId r) - : mask(m), emission(e), weights(t), right(r) {} + : mask(m), + emission(e), + weights(t), + right(r) {} struct Equality { bool operator()(ApplicatorState const& ll, ApplicatorState const& rr) const { @@ -248,7 +251,8 @@ class AbstractApplicator : public Applicator { } StateDegrees(Fsa::ConstAutomatonRef ff, Fsa::ConstAlphabetRef aa) - : Fsa::DfsState(ff), alphabet_(aa) {} + : Fsa::DfsState(ff), + alphabet_(aa) {} const Degree& operator[](Fsa::StateId ii) const { return degrees_[ii]; @@ -258,7 +262,8 @@ class AbstractApplicator : public Applicator { struct StackItem : AppState { Fsa::StateRef result; StackItem(AppState const& state, Fsa::StateRef _result) - : AppState(state), result(_result) {} + : AppState(state), + result(_result) {} }; typedef std::stack StateStack; @@ -954,7 +959,9 @@ Am::TransitionModel* Am::TransitionModel::createTransitionModel(const Core::Conf // =========================================================================== ScaledTransitionModel::ScaledTransitionModel(const Core::Configuration& c, ClassicStateModelRef stateModel) - : Core::Component(c), Mc::Component(c), transitionModel_(0) { + : Core::Component(c), + Mc::Component(c), + transitionModel_(0) { transitionModel_ = TransitionModel::createTransitionModel(c, stateModel); } diff --git a/src/Am/TransitionModel.hh b/src/Am/TransitionModel.hh index 24608a33a..c1327cb5a 100644 --- a/src/Am/TransitionModel.hh +++ b/src/Am/TransitionModel.hh @@ -94,7 +94,8 @@ public: protected: std::vector transitionModels_; - void dump(Core::XmlWriter&) const; + + void dump(Core::XmlWriter&) const; public: TransitionModel(const Core::Configuration&); @@ -107,9 +108,11 @@ public: * @return is false if correction was necessary. 
*/ bool correct(); - u32 nModels() const { + + u32 nModels() const { return transitionModels_.size(); } + const StateTransitionModel* operator[](int i) const { require_(0 <= i && i < (int)transitionModels_.size() && transitionModels_[i] != 0); return transitionModels_[i]; diff --git a/src/Audio/Ffmpeg.cc b/src/Audio/Ffmpeg.cc index 13d1f8811..0eca0e0fe 100644 --- a/src/Audio/Ffmpeg.cc +++ b/src/Audio/Ffmpeg.cc @@ -117,7 +117,12 @@ struct FfmpegInputNode::Internal { }; FfmpegInputNode::FfmpegInputNode(const Core::Configuration& c) - : Core::Component(c), Node(c), Precursor(c), internal_(new Internal()), buffer_(nullptr), resampleRate_(paramResampleRate(c)) { + : Core::Component(c), + Node(c), + Precursor(c), + internal_(new Internal()), + buffer_(nullptr), + resampleRate_(paramResampleRate(c)) { std::call_once(FfmpegInputNode::ffmpeg_initialized, av_register_all); } diff --git a/src/Audio/Flac.cc b/src/Audio/Flac.cc index 1d0676a04..779fe8607 100644 --- a/src/Audio/Flac.cc +++ b/src/Audio/Flac.cc @@ -21,7 +21,10 @@ using namespace Audio; // =========================================================================== FlacInputNode::FlacInputNode(const Core::Configuration& c) - : Core::Component(c), Node(c), SourceNode(c), fd_(0) {} + : Core::Component(c), + Node(c), + SourceNode(c), + fd_(0) {} bool FlacInputNode::openFile_() { fd_ = new FlacDecoder(); diff --git a/src/Audio/Node.cc b/src/Audio/Node.cc index 89e117f40..9f42e2752 100644 --- a/src/Audio/Node.cc +++ b/src/Audio/Node.cc @@ -30,7 +30,12 @@ const Core::ParameterString Node::paramFilename( "file", "name of audio file"); Node::Node(const Core::Configuration& c) - : Core::Component(c), Flow::Node(c), sampleRate_(0), sampleSize_(0), trackCount_(0), sampleCount_(0) { + : Core::Component(c), + Flow::Node(c), + sampleRate_(0), + sampleSize_(0), + trackCount_(0), + sampleCount_(0) { filename_ = paramFilename(config); } @@ -58,7 +63,8 @@ const Core::ParameterFloat SourceNode::paramEndTime( Core::Type::max, 0.0); SourceNode::SourceNode(const Core::Configuration& c) - : Core::Component(c), Node(c) { + : Core::Component(c), + Node(c) { addOutput(0); blockSize_ = paramBlockSize(config); startTime_ = paramStartTime(config); @@ -183,7 +189,8 @@ bool SourceNode::work(Flow::PortId out) { // =========================================================================== Audio::SinkNode::SinkNode(const Core::Configuration& c) - : Core::Component(c), Node(c) { + : Core::Component(c), + Node(c) { addInput(0); addOutput(0); } diff --git a/src/Audio/Node.hh b/src/Audio/Node.hh index 0f53d5bc8..ae59c95e5 100644 --- a/src/Audio/Node.hh +++ b/src/Audio/Node.hh @@ -52,13 +52,15 @@ public: } virtual bool isFileOpen() const; - bool openFile() { + + bool openFile() { require(!isFileOpen()); bool result = openFile_(); sampleCount_ = 0; ensure(result == isFileOpen()); return result; } + void closeFile() { require(isFileOpen()); closeFile_(); @@ -78,6 +80,7 @@ public: if (isFileOpen()) closeFile(); } + virtual bool setParameter(const std::string& name, const std::string& value); }; @@ -124,11 +127,14 @@ protected: /** @return is sample position of startTime_. */ SampleCount getStartSample() const; + /** @return is sample position of endTime_. 
*/ SampleCount getEndSample() const; - void setBlockSize(u32 blockSize) { + + void setBlockSize(u32 blockSize) { blockSize_ = blockSize; } + u32 blockSize() const { return blockSize_; } diff --git a/src/Audio/Oss.cc b/src/Audio/Oss.cc index 87edaf011..b9fbe445d 100644 --- a/src/Audio/Oss.cc +++ b/src/Audio/Oss.cc @@ -29,7 +29,9 @@ const Core::ParameterString Audio::OpenSoundSystemDevice::paramDevice( "device", "name of audio device", "/dev/dsp"); OpenSoundSystemDevice::OpenSoundSystemDevice(const Core::Configuration& c) - : Core::Component(c), Node(c), fd_(-1) { + : Core::Component(c), + Node(c), + fd_(-1) { filename_ = paramDevice(config); } @@ -143,7 +145,10 @@ void OpenSoundSystemDevice::setTrackCount(u8 _trackCount) { // =========================================================================== OpenSoundSystemInputNode::OpenSoundSystemInputNode(const Core::Configuration& c) - : Core::Component(c), Node(c), RawSourceNode(c), OpenSoundSystemDevice(c) {} + : Core::Component(c), + Node(c), + RawSourceNode(c), + OpenSoundSystemDevice(c) {} bool OpenSoundSystemInputNode::openFile_() { return openDevice(); @@ -183,7 +188,10 @@ u32 OpenSoundSystemInputNode::read(u32 nSamples, Flow::Timestamp*& d) { // =========================================================================== OpenSoundSystemOutputNode::OpenSoundSystemOutputNode(const Core::Configuration& c) - : Core::Component(c), Node(c), SinkNode(c), OpenSoundSystemDevice(c) {} + : Core::Component(c), + Node(c), + SinkNode(c), + OpenSoundSystemDevice(c) {} bool OpenSoundSystemOutputNode::openFile_() { return openDevice(); diff --git a/src/Audio/Oss.hh b/src/Audio/Oss.hh index 380e99348..2dfcccef9 100644 --- a/src/Audio/Oss.hh +++ b/src/Audio/Oss.hh @@ -45,13 +45,17 @@ protected: public: static const Core::ParameterString paramDevice; - static std::string filterName() { + + static std::string filterName() { return "audio-input-device-oss"; } + OpenSoundSystemDevice(const Core::Configuration&); + virtual bool isFileOpen() const { return isDeviceOpen(); } + virtual void setSampleRate(Flow::Time _sampleRate); virtual void setSampleSize(u8 _sampleSize); virtual void setTrackCount(u8 _trackCount); diff --git a/src/Audio/Raw.cc b/src/Audio/Raw.cc index eaa4e6a20..33633a9c6 100644 --- a/src/Audio/Raw.cc +++ b/src/Audio/Raw.cc @@ -31,7 +31,9 @@ const Core::ParameterInt RawSourceNode::paramTracks( "track-count", "number of tracks", 1, 1); RawSourceNode::RawSourceNode(const Core::Configuration& c) - : Core::Component(c), Node(c), SourceNode(c) { + : Core::Component(c), + Node(c), + SourceNode(c) { sampleRate_ = paramSampleRate(config); sampleSize_ = paramSampleSize(config); trackCount_ = paramTracks(config); @@ -57,7 +59,9 @@ const Core::ParameterInt RawFileInputNode::paramOffset( "offset", "number of bytes to skip at start of file", 0, 0); RawFileInputNode::RawFileInputNode(const Core::Configuration& c) - : Core::Component(c), Node(c), RawSourceNode(c) { + : Core::Component(c), + Node(c), + RawSourceNode(c) { offset_ = paramOffset(c); } diff --git a/src/Audio/Wav.hh b/src/Audio/Wav.hh index cd0034451..6e722c6be 100644 --- a/src/Audio/Wav.hh +++ b/src/Audio/Wav.hh @@ -67,6 +67,7 @@ public: return "audio-output-file-wav"; } WavOutputNode(const Core::Configuration& c); + virtual ~WavOutputNode() { if (isFileOpen()) closeFile_(); diff --git a/src/Bliss/CorpusDescription.cc b/src/Bliss/CorpusDescription.cc index 84f270539..8fd343b92 100644 --- a/src/Bliss/CorpusDescription.cc +++ b/src/Bliss/CorpusDescription.cc @@ -335,7 +335,9 @@ void 
CorpusDescription::SegmentPartitionVisitorAdaptor::loadSegmentList(const st // --------------------------------------------------------------------------- ProgressReportingVisitorAdaptor::ProgressReportingVisitorAdaptor(Core::XmlChannel& ch, bool reportOrth) - : visitor_(0), channel_(ch), reportSegmentOrth_(reportOrth) {} + : visitor_(0), + channel_(ch), + reportSegmentOrth_(reportOrth) {} void ProgressReportingVisitorAdaptor::enterCorpus(Corpus* c) { channel_ << Core::XmlOpen((c->level()) ? "subcorpus" : "corpus") + Core::XmlAttribute("name", c->name()) + Core::XmlAttribute("full-name", c->fullName()); @@ -432,7 +434,9 @@ class CorpusDescription::ProgressIndicationVisitorAdaptor : public CorpusVisitor public: ProgressIndicationVisitorAdaptor() - : nSegments_(0), visitor_(0), pi_("traversing corpus", "segments") {} + : nSegments_(0), + visitor_(0), + pi_("traversing corpus", "segments") {} void setVisitor(CorpusVisitor* v) { visitor_ = v; } diff --git a/src/Bliss/CorpusDescription.hh b/src/Bliss/CorpusDescription.hh index 6c853542c..9407f6a06 100644 --- a/src/Bliss/CorpusDescription.hh +++ b/src/Bliss/CorpusDescription.hh @@ -93,10 +93,12 @@ class Speaker : public NamedCorpusEntity { friend class SpeakerDescriptionElement; public: - enum Gender { unknown, - male, - female, - nGenders }; + enum Gender { + unknown, + male, + female, + nGenders + }; static const char* genderId[nGenders]; private: @@ -320,9 +322,11 @@ class Segment : public ParentEntity { friend class CorpusDescriptionParser; public: - enum Type { typeSpeech, - typeOther, - nTypes }; + enum Type { + typeSpeech, + typeOther, + nTypes + }; static const char* typeId[nTypes]; private: @@ -563,9 +567,11 @@ private: ProgressReportingVisitorAdaptor* reporter_; - enum ProgressIndcationMode { noProgress, - localProgress, - globalProgress }; + enum ProgressIndcationMode { + noProgress, + localProgress, + globalProgress + }; static const Core::Choice progressIndicationChoice; static const Core::ParameterChoice paramProgressIndication; ProgressIndcationMode progressIndicationMode_; diff --git a/src/Bliss/CorpusParser.cc b/src/Bliss/CorpusParser.cc index 16c043abe..a9a2e6ec2 100644 --- a/src/Bliss/CorpusParser.cc +++ b/src/Bliss/CorpusParser.cc @@ -298,7 +298,8 @@ void CorpusDescriptionParser::initSchema() { } CorpusDescriptionParser::CorpusDescriptionParser(const Configuration& c) - : XmlSchemaParser(c), progressIndicator_(paramProgress(c) ? new Core::ProgressIndicator("CorpusDescriptionParser", "segments") : 0) { + : XmlSchemaParser(c), + progressIndicator_(paramProgress(c) ? new Core::ProgressIndicator("CorpusDescriptionParser", "segments") : 0) { initSchema(); isSubParser_ = false; @@ -312,7 +313,8 @@ CorpusDescriptionParser::CorpusDescriptionParser(const Configuration& c) } CorpusDescriptionParser::CorpusDescriptionParser(const Configuration& c, Corpus* _corpus) - : XmlSchemaParser(c), progressIndicator_(paramProgress(c) ? new Core::ProgressIndicator("CorpusDescriptionParser", "segments") : 0) { + : XmlSchemaParser(c), + progressIndicator_(paramProgress(c) ? 
new Core::ProgressIndicator("CorpusDescriptionParser", "segments") : 0) { initSchema(); isSubParser_ = true; diff --git a/src/Bliss/CorpusStatistics.hh b/src/Bliss/CorpusStatistics.hh index 60725d3e2..0b94e4264 100644 --- a/src/Bliss/CorpusStatistics.hh +++ b/src/Bliss/CorpusStatistics.hh @@ -61,7 +61,8 @@ private: unsigned int nSegments; Time totalDuration; SpeakerStatistics() - : nSegments(0), totalDuration(0.0) {} + : nSegments(0), + totalDuration(0.0) {} }; void writeSpeakerStatistics(const SpeakerStatistics&, Core::XmlWriter&) const; typedef Core::StringHashMap SpeakerStatisticsMap; @@ -86,7 +87,8 @@ private: unsigned int nSegments; Time totalDuration; ConditionStatistics() - : nSegments(0), totalDuration(0.0) {} + : nSegments(0), + totalDuration(0.0) {} }; void writeConditionStatistics(const ConditionStatistics&, Core::XmlWriter&) const; typedef Core::StringHashMap ConditionStatisticsMap; diff --git a/src/Bliss/EditDistance.cc b/src/Bliss/EditDistance.cc index bafdeb33a..9a145b181 100644 --- a/src/Bliss/EditDistance.cc +++ b/src/Bliss/EditDistance.cc @@ -127,7 +127,8 @@ struct EditDistance::Trace : public Core::ReferenceCounted, struct EditDistance::State { Fsa::StateId a, b; - bool operator==(const State& rhs) const { + + bool operator==(const State& rhs) const { return a == rhs.a && b == rhs.b; } @@ -445,10 +446,18 @@ void ErrorStatistic::clear() { } ErrorStatistic::ErrorStatistic() - : nLeftTokens_(0), nRightTokens_(0), nInsertions_(0), nDeletions_(0), nSubstitutions_(0) {} + : nLeftTokens_(0), + nRightTokens_(0), + nInsertions_(0), + nDeletions_(0), + nSubstitutions_(0) {} ErrorStatistic::ErrorStatistic(const EditDistance::Alignment& a) - : nLeftTokens_(0), nRightTokens_(0), nInsertions_(0), nDeletions_(0), nSubstitutions_(0) { + : nLeftTokens_(0), + nRightTokens_(0), + nInsertions_(0), + nDeletions_(0), + nSubstitutions_(0) { *this += a; } diff --git a/src/Bliss/EditDistance.hh b/src/Bliss/EditDistance.hh index bba4f4d8d..18f24ac2e 100644 --- a/src/Bliss/EditDistance.hh +++ b/src/Bliss/EditDistance.hh @@ -62,23 +62,29 @@ namespace Bliss { class EditDistance : public Core::Component { public: - typedef enum { Deletion, - Insertion, - Substitution, - Correct, - Empty } EditOperation; + typedef enum { + Deletion, + Insertion, + Substitution, + Correct, + Empty + } EditOperation; typedef f32 Score; typedef u32 Cost; typedef std::pair EditCost; - typedef enum { FormatBliss, - FormatNist } AlignmentFormat; + typedef enum { + FormatBliss, + FormatNist + } AlignmentFormat; struct AlignmentItem { const Token * a, *b; EditOperation op; AlignmentItem(const Bliss::Token* _a, const Bliss::Token* _b, EditOperation _op) - : a(_a), b(_b), op(_op) {} + : a(_a), + b(_b), + op(_op) {} void write(std::ostream& o) const; friend std::ostream& operator<<(std::ostream& o, const AlignmentItem& a) { @@ -99,7 +105,9 @@ public: public: Alignment(AlignmentFormat format = FormatBliss) - : format_(format), score(0), cost(0) {} + : format_(format), + score(0), + cost(0) {} void clear(); void write(Core::XmlWriter& xml) const; }; @@ -155,21 +163,27 @@ public: void operator+=(const EditDistance::Alignment&); void operator+=(const ErrorStatistic&); void write(Core::XmlWriter&) const; - u32 nLeftTokens() const { + + u32 nLeftTokens() const { return nLeftTokens_; } + u32 nRightTokens() const { return nLeftTokens_ - nDeletions_ + nInsertions_; } + u32 nInsertions() const { return nInsertions_; } + u32 nDeletions() const { return nDeletions_; } + u32 nSubstitutions() const { return nSubstitutions_; } + u32 nErrors() 
const {
        return nInsertions_ + nDeletions_ + nSubstitutions_;
    }
diff --git a/src/Bliss/Fsa.cc b/src/Bliss/Fsa.cc
index 02a0c19b0..c2cc83ae6 100644
--- a/src/Bliss/Fsa.cc
+++ b/src/Bliss/Fsa.cc
@@ -19,10 +19,14 @@ using namespace Bliss;
 TokenAlphabet::TokenAlphabet(const TokenInventory& ti)
-        : lexicon_(), tokens_(ti), nDisambiguators_(0) {}
+        : lexicon_(),
+          tokens_(ti),
+          nDisambiguators_(0) {}
 TokenAlphabet::TokenAlphabet(LexiconRef l, const TokenInventory& ti)
-        : lexicon_(l), tokens_(ti), nDisambiguators_(0) {}
+        : lexicon_(l),
+          tokens_(ti),
+          nDisambiguators_(0) {}
 TokenAlphabet::~TokenAlphabet() {}
diff --git a/src/Bliss/Fsa.hh b/src/Bliss/Fsa.hh
index 9b271e81e..86de1e1e7 100644
--- a/src/Bliss/Fsa.hh
+++ b/src/Bliss/Fsa.hh
@@ -116,7 +116,8 @@ private:
     friend class Lexicon;
     mutable u32 nDisambiguators_;
     LemmaPronunciationAlphabet(LexiconRef l)
-            : lexicon_(l), nDisambiguators_(0) {}
+            : lexicon_(l),
+              nDisambiguators_(0) {}
 public:
     Fsa::LabelId index(const LemmaPronunciation* l) const {
diff --git a/src/Bliss/Lexicon.cc b/src/Bliss/Lexicon.cc
index ef091a569..fd353bfd8 100644
--- a/src/Bliss/Lexicon.cc
+++ b/src/Bliss/Lexicon.cc
@@ -56,7 +56,8 @@ bool Lemma::hasPronunciation(const Pronunciation* pron) const {
 // ===========================================================================
 Pronunciation::Pronunciation(const Phoneme::Id* _phonemes)
-        : lemmas_(0), phonemes_(_phonemes) {}
+        : lemmas_(0),
+          phonemes_(_phonemes) {}
 Pronunciation::~Pronunciation() {}
@@ -128,8 +129,7 @@ struct Lexicon::Internal {
 Lexicon::Lexicon(const Configuration& c)
         : Component(c),
           symbolSequences_(symbols_),
-          internal_(0) {
-    internal_ = new Internal;
+          internal_(new Internal) {
 }
 Lexicon::~Lexicon() {
@@ -199,9 +199,9 @@ void Lexicon::setOrthographicForms(Lemma* lemma, const std::vector&
         require(isWhitespaceNormalized(*orth));
         const char *cc, *nc;
         for (cc = nc = orth->c_str(); *cc; cc = nc) {
-            do
+            do {
                 ++nc;
-            while (*nc && utf8::byteType(*nc) == utf8::multiByteTail);
+            } while (*nc && utf8::byteType(*nc) == utf8::multiByteTail);
             letter(std::string(cc, nc));
         }
     }
@@ -222,6 +222,14 @@ void Lexicon::setDefaultLemmaName(Lemma* lemma) {
     lemmas_.link(lemma->name(), lemma);
 }
+void Lexicon::setLemmaName(Lemma* lemma, Symbol symbol) {
+    require(lemma);
+    require(!lemma->hasName());
+    verify(isWhitespaceNormalized(symbol.str()));
+    lemma->setName(symbol);
+    lemmas_.link(lemma->name(), lemma);
+}
+
 const Letter* Lexicon::letter(const std::string& letter) const {
     Token* token = letters_[letter.c_str()];
     if (!token) {
@@ -294,6 +302,11 @@ Core::Status Lexicon::getPronunciation(const std::string& phon, Pronunciation*&
     return status;
 }
+Pronunciation* Lexicon::getPronunciation(const std::vector& phonemes) {
+    require(phonemeInventory());
+    return getOrCreatePronunciation(phonemes);
+}
+
 void Lexicon::addPronunciation(Lemma* lemma, Pronunciation* pron, f32 weight) {
     require(lemma);
     require(pron);
@@ -350,6 +363,20 @@ void Lexicon::setSyntacticTokenSequence(Lemma* lemma, const std::vectorsetSyntacticTokenSequence(tokenSequence);
 }
+void Lexicon::setSyntacticTokenSequence(Lemma* lemma, const std::vector& synt_ids) {
+    require(lemma);
+    synts_.start();
+    for (Token::Id id : synt_ids) {
+        SyntacticToken* token = static_cast(syntacticTokens_[id]);
+        // SyntacticToken* token = syntacticTokens_[id];
+        synts_.grow(token);
+        token->addLemma(lemma);
+    }
+    SyntacticTokenSequence tokenSequence(synts_.currentBegin(), synts_.currentEnd());
+    synts_.finish();
+    lemma->setSyntacticTokenSequence(tokenSequence);
+}
+
 void Lexicon::setDefaultSyntacticToken(Lemma* lemma) {
     require(lemma);
     require(lemma->nOrthographicForms());
@@ -381,6 +408,18 @@ void Lexicon::addEvaluationTokenSequence(Lemma* lemma, const std::vectoraddEvaluationTokenSequence(tokenSequence);
 }
+void Lexicon::addEvaluationTokenSequence(Lemma* lemma, const std::vector& ids) {
+    require(lemma);
+    evals_.start();
+    for (auto id : ids) {
+        EvaluationToken* token = static_cast(evaluationTokens_[id]);
+        evals_.grow(token);
+    }
+    EvaluationTokenSequence tokenSequence(evals_.currentBegin(), evals_.currentEnd());
+    evals_.finish();
+    lemma->addEvaluationTokenSequence(tokenSequence);
+}
+
 void Lexicon::setDefaultEvaluationToken(Lemma* lemma) {
     require(lemma);
     require(lemma->nOrthographicForms());
@@ -392,14 +431,32 @@ void Lexicon::setDefaultEvaluationToken(Lemma* lemma) {
 }
 void Lexicon::defineSpecialLemma(const std::string& name, Lemma* lemma) {
-    require(!specialLemma(name));
     require(lemma);
-    specialLemmas_[name] = lemma;
+
+    specialLemmas_[name].insert(lemma);
 }
+void Lexicon::removeSpecialLemma(const Lemma* lemma) {
+    for (SpecialLemmaMap::iterator iter = specialLemmas_.begin(); iter != specialLemmas_.end(); ++iter)
+        if (iter->second.count(lemma) > 0) {
+            iter->second.erase(lemma);
+            break;
+        }
+}
+
+// Note: we only return the first lemma (mostly just one for internal specified special names)
 const Lemma* Lexicon::specialLemma(const std::string& name) const {
-    LemmaMap::const_iterator i = specialLemmas_.find(name);
-    return (i == specialLemmas_.end()) ? 0 : i->second;
+    SpecialLemmaMap::const_iterator i = specialLemmas_.find(name);
+    return (i == specialLemmas_.end()) ? 0 : *(i->second.begin());
+}
+
+std::string Lexicon::getSpecialLemmaName(const Lemma* lemma) const {
+    for (const auto it : specialLemmas_) {
+        if (it.second.count(lemma) > 0)
+            return it.first;
+    }
+
+    return "";
 }
 void Lexicon::writeXml(Core::XmlWriter& os) const {
@@ -410,12 +467,7 @@ void Lexicon::writeXml(Core::XmlWriter& os) const {
     for (tie(l, l_end) = lemmas(); l != l_end; ++l) {
         const Lemma* lemma(*l);
-        std::string specialName = "";
-        for (LemmaMap::const_iterator i = specialLemmas_.begin();
-             i != specialLemmas_.end(); ++i) {
-            if (i->second->id() == lemma->id())
-                specialName = i->first;
-        }
+        std::string specialName = getSpecialLemmaName(lemma);
         if (specialName.length() == 0)
             os << Core::XmlOpen("lemma");
         else
diff --git a/src/Bliss/Lexicon.hh b/src/Bliss/Lexicon.hh
index dd2955226..83967d26f 100644
--- a/src/Bliss/Lexicon.hh
+++ b/src/Bliss/Lexicon.hh
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include
 #include "Phoneme.hh"
 #include "Symbol.hh"
@@ -418,6 +419,10 @@ protected:
         lemmas_.push_back(lemma);
     }
+    void reduceLemma() {
+        lemmas_.pop_back(); // only size matter, not order
+    }
+
 public:
     /** The number of lemmata this token occurs in. */
     u32 nLemmas() const {
@@ -516,9 +521,11 @@ protected:
     // lemmas
     friend class Bliss::LemmaAlphabet;
-    TokenInventory lemmas_;
-    typedef Core::StringHashMap LemmaMap;
-    LemmaMap specialLemmas_;
+    TokenInventory lemmas_;
+
+    typedef robin_hood::unordered_map> SpecialLemmaMap;
+    SpecialLemmaMap specialLemmas_;
+    typedef Core::StringHashMap LemmaMap;
     friend class Bliss::LemmaPronunciationAlphabet;
     typedef std::vector LemmaPronunciationList;
@@ -575,13 +582,15 @@ public:
      * Set the unique name of a lemma.
      */
     void setDefaultLemmaName(Lemma* lemma);
+    void setLemmaName(Lemma* lemma, Symbol symbol);
     /**
      * Get a pronunciation for a string representation.
     * @param phon a string containing a white-space separate list
     * of phoneme symbols.
     */
-    Core::Status getPronunciation(const std::string& phon, Pronunciation*& out);
+    Core::Status getPronunciation(const std::string& phon, Pronunciation*& out);
+    Pronunciation* getPronunciation(const std::vector& phonemes);
     /**
      * Add a pronunciation to a lemma.
@@ -599,12 +608,14 @@ public:
      * Set the a syntactic token sequence for a lemma.
      */
     void setSyntacticTokenSequence(Lemma* lemma, const std::vector& synt);
+    void setSyntacticTokenSequence(Lemma* lemma, const std::vector& synt);
     void setDefaultSyntacticToken(Lemma* lemma);
     /**
      * Set the a evaluation token sequence for a lemma.
      */
     void addEvaluationTokenSequence(Lemma* lemma, const std::vector& eval);
+    void addEvaluationTokenSequence(Lemma* lemma, const std::vector& ids);
     void setDefaultEvaluationToken(Lemma* lemma);
     /**
@@ -615,6 +626,8 @@ public:
      */
     void defineSpecialLemma(const std::string& name, Lemma* lemma);
+    void removeSpecialLemma(const Lemma* lemma);
+
     /**
      * Load lexicon from XML or txt file.
     */
@@ -706,6 +719,14 @@ public:
      */
     const Lemma* specialLemma(const std::string& name) const;
+    /**
+     * If a lemma is declared as special, return the string specified
+     * in the "special" attribute ().
+     * @return empty string if a lemma is not special, or the
+     * special attribute specified in the lexicon.
+     */
+    std::string getSpecialLemmaName(const Lemma* lemma) const;
+
     Core::Ref lemmaAlphabet() const;
     // -------------------------------------------------------------------
diff --git a/src/Bliss/Phoneme.cc b/src/Bliss/Phoneme.cc
index 21853948f..35e5669c8 100644
--- a/src/Bliss/Phoneme.cc
+++ b/src/Bliss/Phoneme.cc
@@ -22,7 +22,8 @@ using namespace Bliss;
 const Phoneme::Id Phoneme::term;
 Phoneme::Phoneme()
-        : Token(), isContextDependent_(true) {}
+        : Token(),
+          isContextDependent_(true) {}
 struct PhonemeInventory::Internal {
     SymbolSet symbols_;
diff --git a/src/Bliss/Phonology.cc b/src/Bliss/Phonology.cc
index fe42c0b49..499b982dd 100644
--- a/src/Bliss/Phonology.cc
+++ b/src/Bliss/Phonology.cc
@@ -28,7 +28,7 @@ bool ContextPhonology::SemiContext::empty() const {
     }
     else if ((*begin()) == Phoneme::term) {
         // The 'term' element cannot be followed by not 'term' element.
-        verify(std::find_if(begin(), end(), std::bind2nd(std::not_equal_to(), Phoneme::term)) == end());
+        verify(std::find_if(begin(), end(), std::bind(std::not_equal_to(), std::placeholders::_1, Phoneme::term)) == end());
         return true;
     }
     return false;
 }
diff --git a/src/Bliss/Phonology.hh b/src/Bliss/Phonology.hh
index 5e54d28e7..1ffa34723 100644
--- a/src/Bliss/Phonology.hh
+++ b/src/Bliss/Phonology.hh
@@ -91,7 +91,8 @@ public:
     SemiContext history, future;
     Context() {}
     Context(const SemiContext& h, const SemiContext& f)
-            : history(h), future(f) {}
+            : history(h),
+              future(f) {}
     bool operator==(const Context& rhs) const {
         return history == rhs.history && future == rhs.future;
     }
@@ -164,7 +165,8 @@ public:
     PhonemeInContext()
             : phoneme_(Phoneme::term){};
     PhonemeInContext(Phoneme::Id phoneme, const SemiContext& history = SemiContext(), const SemiContext& future = SemiContext())
-            : phoneme_(phoneme), context_(history, future) {}
+            : phoneme_(phoneme),
+              context_(history, future) {}
     bool operator==(const PhonemeInContext& rhs) const {
         return phoneme_ == rhs.phoneme_ && context_ == rhs.context_;
diff --git a/src/Bliss/SegmentOrdering.cc b/src/Bliss/SegmentOrdering.cc
index 682c68d5e..f101d0d03 100644
--- a/src/Bliss/SegmentOrdering.cc
+++ b/src/Bliss/SegmentOrdering.cc
@@ -264,7 +264,10 @@ void SegmentOrderingVisitor::leaveCorpus(Corpus* corpus) {
 }
 SegmentOrderingVisitor::CustomCorpusGuide::CustomCorpusGuide(SegmentOrderingVisitor* parent, Corpus* rootCorpus)
-        : parent_(parent), rootCorpus_(rootCorpus), curCorpus_(rootCorpus_), curRecording_(0) {
+        : parent_(parent),
+          rootCorpus_(rootCorpus),
+          curCorpus_(rootCorpus_),
+          curRecording_(0) {
     // Enter the root corpus.
     parent_->visitor_->enterCorpus(rootCorpus_);
 }
diff --git a/src/Bliss/Symbol.hh b/src/Bliss/Symbol.hh
index 70d924341..01a0b7848 100644
--- a/src/Bliss/Symbol.hh
+++ b/src/Bliss/Symbol.hh
@@ -165,11 +165,13 @@ private:
 protected:
     Token(Id _id, Bliss::Symbol _symbol)
-            : id_(_id), symbol_(_symbol) {}
+            : id_(_id),
+              symbol_(_symbol) {}
     Token(Id _id)
             : id_(_id) {}
     Token(Bliss::Symbol _symbol)
-            : id_(invalidId), symbol_(_symbol) {}
+            : id_(invalidId),
+              symbol_(_symbol) {}
     Token()
             : id_(invalidId) {}
     friend class TokenInventory;
@@ -240,7 +242,8 @@ public:
     }
     typedef Token* const* Iterator;
-    Iterator begin() const {
+
+    Iterator begin() const {
         return &(*list_.begin());
     }
@@ -311,13 +314,16 @@ protected:
 public: // FIXME
     SymbolSequence(const Symbol* _begin, const Symbol* _end)
-            : begin_(_begin), end_(_end) {}
+            : begin_(_begin),
+              end_(_end) {}
 public:
     SymbolSequence()
-            : begin_(0), end_(0) {}
+            : begin_(0),
+              end_(0) {}
     SymbolSequence(const SymbolSequence& o)
-            : begin_(o.begin_), end_(o.end_) {}
+            : begin_(o.begin_),
+              end_(o.end_) {}
     bool valid() const {
         return begin_ != 0;
diff --git a/src/Cart/DecisionTree.hh b/src/Cart/DecisionTree.hh
index 4bd325039..7cac5ecb0 100644
--- a/src/Cart/DecisionTree.hh
+++ b/src/Cart/DecisionTree.hh
@@ -69,7 +69,9 @@ public:
 public:
     Question(PropertyMapRef map, const std::string& desc = "")
-            : map_(map), desc_(desc), index(InvalidQuestionIndex) {}
+            : map_(map),
+              desc_(desc),
+              index(InvalidQuestionIndex) {}
     virtual ~Question() {}
     virtual Answer operator()(const Properties& props) const = 0;
diff --git a/src/Cart/DecisionTreeTrainer.hh b/src/Cart/DecisionTreeTrainer.hh
index 9a19161bf..96a2d702a 100644
--- a/src/Cart/DecisionTreeTrainer.hh
+++ b/src/Cart/DecisionTreeTrainer.hh
@@ -43,7 +43,8 @@ struct ExamplePtrRange {
     ExamplePtrList::const_iterator begin;
     ExamplePtrList::const_iterator end;
ExamplePtrRange(ExamplePtrList::const_iterator begin, ExamplePtrList::const_iterator end) - : begin(begin), end(end) {} + : begin(begin), + end(end) {} size_t size() const { return end - begin; } @@ -73,10 +74,10 @@ public: const ExamplePtrRange& leftExamples, const ExamplePtrRange& rightExamples, const Score fatherScore, Score& leftChildScore, Score& rightChildScore) const = 0; - virtual void operator()( - const ExamplePtrRange& examples, - Score& score) const { + + virtual void operator()(const ExamplePtrRange& examples, Score& score) const { Score dummy; + operator()(examples, ExamplePtrRange(ExamplePtrList::const_iterator(), ExamplePtrList::const_iterator()), 0.0, score, dummy); } }; diff --git a/src/Cart/Example.hh b/src/Cart/Example.hh index 16712ebe9..ca5a9423d 100644 --- a/src/Cart/Example.hh +++ b/src/Cart/Example.hh @@ -59,7 +59,8 @@ public: public: row_vector_iterator(f64* begin, size_t size) - : ptr_(begin), size_(size) {} + : ptr_(begin), + size_(size) {} row_iterator begin() { return ptr_; @@ -99,7 +100,8 @@ public: public: const_row_vector_iterator(f64* begin, size_t size) - : ptr_(begin), size_(size) {} + : ptr_(begin), + size_(size) {} const_row_iterator begin() const { return ptr_; @@ -132,7 +134,8 @@ public: public: column_iterator(f64* begin, size_t offset) - : ptr_(begin), offset_(offset) {} + : ptr_(begin), + offset_(offset) {} f64& operator*() { return *ptr_; @@ -155,7 +158,8 @@ public: public: const_column_iterator(f64* begin, size_t offset) - : ptr_(begin), offset_(offset) {} + : ptr_(begin), + offset_(offset) {} f64 operator*() const { return *ptr_; @@ -179,7 +183,9 @@ public: public: column_vector_iterator(f64* begin, size_t size, size_t overall_size) - : ptr_(begin), size_(size), overall_size_(overall_size) {} + : ptr_(begin), + size_(size), + overall_size_(overall_size) {} column_iterator begin() { return column_iterator(ptr_, size_); @@ -220,7 +226,9 @@ public: public: const_column_vector_iterator(f64* begin, size_t size, size_t overall_size) - : ptr_(begin), size_(size), overall_size_(overall_size) {} + : ptr_(begin), + size_(size), + overall_size_(overall_size) {} const_column_iterator begin() const { return const_column_iterator(ptr_, size_); diff --git a/src/Cart/Properties.hh b/src/Cart/Properties.hh index d6abe4131..2011d69ba 100644 --- a/src/Cart/Properties.hh +++ b/src/Cart/Properties.hh @@ -271,7 +271,8 @@ public: struct PtrHashFcn { HashFcn hasher; - size_t operator()(const Properties* const props) const { + + size_t operator()(const Properties* const props) const { return hasher(*props); } }; diff --git a/src/Core/Archive.cc b/src/Core/Archive.cc index 7de1d64aa..07d5905c1 100644 --- a/src/Core/Archive.cc +++ b/src/Core/Archive.cc @@ -27,7 +27,9 @@ using namespace Core; Archive::Archive(const Core::Configuration& config, const std::string& path, AccessMode access) - : Component(config), path_(path), access_(access) { + : Component(config), + path_(path), + access_(access) { } bool Archive::hasFile(const std::string& name) const { @@ -361,7 +363,10 @@ class ArchiveWriterBuffer : public std::streambuf { } // namespace ArchiveWriter::ArchiveWriter(Archive& archive, const std::string& path, bool compress) - : std::ostream(new ArchiveWriterBuffer(buffer_)), archive_(archive), path_(path), compress_(compress) { + : std::ostream(new ArchiveWriterBuffer(buffer_)), + archive_(archive), + path_(path), + compress_(compress) { isOpen_ = true; } diff --git a/src/Core/Archive.hh b/src/Core/Archive.hh index 5bdf4590c..f86f65b24 100644 --- a/src/Core/Archive.hh +++ 
b/src/Core/Archive.hh @@ -37,10 +37,12 @@ public: /** * Different types of archives handled by this class. **/ - enum Type { TypeUnknown, - TypeDirectory, - TypeFile, - TypeBundle }; + enum Type { + TypeUnknown, + TypeDirectory, + TypeFile, + TypeBundle + }; /** * Different access types to archives. @@ -59,7 +61,8 @@ public: public: Sizes(u32 uncompressed = 0, u32 compressed = 0) - : uncompressed_(uncompressed), compressed_(compressed) {} + : uncompressed_(uncompressed), + compressed_(compressed) {} void setUncompressed(u32 uncompressed) { uncompressed_ = uncompressed; } @@ -83,9 +86,11 @@ public: virtual ~_const_iterator() {} virtual _const_iterator& operator++() = 0; virtual operator bool() const = 0; - const std::string& name() { + + const std::string& name() { return name_; } + const Sizes& sizes() { return sizes_; } diff --git a/src/Core/BinaryStream.cc b/src/Core/BinaryStream.cc index 6cab20a45..9f1e19ce6 100644 --- a/src/Core/BinaryStream.cc +++ b/src/Core/BinaryStream.cc @@ -22,7 +22,8 @@ using namespace Core; // class BinaryOutputStream BinaryStreamIos::BinaryStreamIos(Endianess endianess) - : ios_(0), fstream_(0) { + : ios_(0), + fstream_(0) { ios_ = fstream_ = new std::fstream; setEndianess(endianess); } @@ -36,7 +37,8 @@ BinaryStreamIos::BinaryStreamIos(std::ios& i, Endianess endianess) BinaryStreamIos::BinaryStreamIos(const std::string& fileName, std::ios_base::openmode mode, Endianess endianess) - : ios_(0), fstream_(0) { + : ios_(0), + fstream_(0) { mode |= std::ios::binary; fstream_ = new std::fstream(fileName.c_str(), mode); ios_ = fstream_; diff --git a/src/Core/BinaryStream.hh b/src/Core/BinaryStream.hh index 6b92d6eae..316dc9c9d 100644 --- a/src/Core/BinaryStream.hh +++ b/src/Core/BinaryStream.hh @@ -57,9 +57,11 @@ namespace Core { class BinaryStreamIos { public: typedef std::ios_base::iostate iostate; - enum Endianess { bigEndian, - littleEndian, - nativeByteOrder }; + enum Endianess { + bigEndian, + littleEndian, + nativeByteOrder + }; static const Endianess defaultEndianess = littleEndian; protected: @@ -149,9 +151,11 @@ public: public: BinaryOutputStream(Endianess endianess = defaultEndianess) - : BinaryStreamIos(endianess), os_(fstream_) {} + : BinaryStreamIos(endianess), + os_(fstream_) {} explicit BinaryOutputStream(std::ostream& stream, Endianess endianess = defaultEndianess) - : BinaryStreamIos(stream, endianess), os_(&stream) {} + : BinaryStreamIos(stream, endianess), + os_(&stream) {} explicit BinaryOutputStream(const std::string& fileName, std::ios_base::openmode mode = std::ios::out, Endianess endianess = defaultEndianess) @@ -269,9 +273,11 @@ public: public: BinaryInputStream(Endianess endianess = defaultEndianess) - : BinaryStreamIos(endianess), is_(fstream_) {} + : BinaryStreamIos(endianess), + is_(fstream_) {} BinaryInputStream(std::istream& stream, Endianess endianess = defaultEndianess) - : BinaryStreamIos(stream, endianess), is_(&stream) {} + : BinaryStreamIos(stream, endianess), + is_(&stream) {} BinaryInputStream(const std::string& fileName, std::ios_base::openmode mode = std::ios::in, Endianess endianess = defaultEndianess) diff --git a/src/Core/BinaryTree.hh b/src/Core/BinaryTree.hh index 6f8537f07..86b579c87 100644 --- a/src/Core/BinaryTree.hh +++ b/src/Core/BinaryTree.hh @@ -36,9 +36,11 @@ public: LeafNumber leafNumber; TreeStructureEntry() - : id(0), leafNumber(0){}; + : id(0), + leafNumber(0){}; TreeStructureEntry(Id i, LeafNumber n) - : id(i), leafNumber(n){}; + : id(i), + leafNumber(n){}; void read(Core::BinaryInputStream& i) { i >> id >> 
leafNumber; @@ -58,7 +60,11 @@ public: u16 leafNumber_; Node() - : left_(0), right_(0), previous_(0), id_(0), leafNumber_(0){}; + : left_(0), + right_(0), + previous_(0), + id_(0), + leafNumber_(0){}; }; typedef std::list LeafList; diff --git a/src/Core/BitStream.hh b/src/Core/BitStream.hh index 4aabe7297..5875fa2fe 100644 --- a/src/Core/BitStream.hh +++ b/src/Core/BitStream.hh @@ -74,7 +74,10 @@ constexpr size_t bitsizeof() { template BitStream::BitStream() - : posg_(0ul), posp_(0ul), size_(0ul), store_() { + : posg_(0ul), + posp_(0ul), + size_(0ul), + store_() { } template diff --git a/src/Core/Channel.cc b/src/Core/Channel.cc index e5ee3591f..c5e5d27c4 100644 --- a/src/Core/Channel.cc +++ b/src/Core/Channel.cc @@ -101,7 +101,11 @@ const Core::ParameterBool Channel::Target::paramAddSprintTags( true); Channel::Target::Target(const Core::Configuration& c, bool isXmlDocument, const std::string& defaultFilename, std::streambuf* defaultStreamBuf) - : Core::Configurable(c), isTty_(false), xml_(*this), isXmlDocument_(isXmlDocument), defaultStreamBuf_(defaultStreamBuf) { + : Core::Configurable(c), + isTty_(false), + xml_(*this), + isXmlDocument_(isXmlDocument), + defaultStreamBuf_(defaultStreamBuf) { open(paramFilename(config, defaultFilename)); setup(); } @@ -133,7 +137,11 @@ class FileStreamBuffer : public std::filebuf { #endif Channel::Target::Target(const Core::Configuration& c, bool isXmlDocument, std::ostream* defaultStream) - : Core::Configurable(c), isTty_(false), xml_(*this), isXmlDocument_(isXmlDocument), defaultStreamBuf_(defaultStream->rdbuf()) { + : Core::Configurable(c), + isTty_(false), + xml_(*this), + isXmlDocument_(isXmlDocument), + defaultStreamBuf_(defaultStream->rdbuf()) { require(defaultStream); std::string filename = paramFilename(config); if (filename.size()) { diff --git a/src/Core/Channel.hh b/src/Core/Channel.hh index f965e83d8..00114df8d 100644 --- a/src/Core/Channel.hh +++ b/src/Core/Channel.hh @@ -106,11 +106,15 @@ namespace Core { class Channel : public std::ostream { public: - enum TargetType { plainTarget, - xmlTarget }; - enum Default { disabled, - standard, - error }; + enum TargetType { + plainTarget, + xmlTarget + }; + enum Default { + disabled, + standard, + error + }; static const Core::ParameterString paramTargets; private: diff --git a/src/Core/Choice.cc b/src/Core/Choice.cc index 2912f75e0..b072c6c4c 100644 --- a/src/Core/Choice.cc +++ b/src/Core/Choice.cc @@ -21,7 +21,8 @@ using namespace Core; const Choice::Value Choice::IllegalValue = -1; const char* const Choice::IllegalIdentifier = ""; -const char* Choice::endMark() { + +const char* Choice::endMark() { return (const char*)0; } diff --git a/src/Core/Choice.hh b/src/Core/Choice.hh index b58e1d178..638e24dc0 100644 --- a/src/Core/Choice.hh +++ b/src/Core/Choice.hh @@ -48,16 +48,21 @@ private: public: Item(const char* ident, const Value value) - : ident_(ident), value_(value) {} + : ident_(ident), + value_(value) {} + const std::string& ident() const { return ident_; } + Value value() const { return value_; } + bool operator<(const Item& item) const { return ident_.compare(item.ident_) < 0; } + bool operator==(const Item& item) const { return (value_ == item.value_) && (ident_ == item.ident_); } @@ -151,9 +156,11 @@ public: * Iterator **/ typedef std::set::const_iterator const_iterator; - const_iterator begin() const { + + const_iterator begin() const { return items_by_ident.begin(); } + const_iterator end() const { return items_by_ident.end(); } diff --git a/src/Core/Component.hh 
b/src/Core/Component.hh index d8d90ddfc..f50fb0322 100644 --- a/src/Core/Component.hh +++ b/src/Core/Component.hh @@ -202,7 +202,9 @@ public: private: friend class Component; Message(const Component* c, ErrorType type, XmlChannel* ch) - : ostream_(ch), component_(c), type_(type) {} + : ostream_(ch), + component_(c), + type_(type) {} public: operator XmlWriter&() const { @@ -218,7 +220,9 @@ public: __attribute__((format(printf, 2, 3))); Message(const Message& m) - : ostream_(m.ostream_), component_(m.component_), type_(m.type_) { + : ostream_(m.ostream_), + component_(m.component_), + type_(m.type_) { const_cast(m).component_ = 0; } diff --git a/src/Core/CompressedStream.cc b/src/Core/CompressedStream.cc index 2d57d3f10..c531acb94 100644 --- a/src/Core/CompressedStream.cc +++ b/src/Core/CompressedStream.cc @@ -27,10 +27,14 @@ namespace Core { // *************************************************************************** CompressedInputStream::CompressedInputStream() - : std::istream(0), file_buf_(nullptr), buf_(nullptr) {} + : std::istream(0), + file_buf_(nullptr), + buf_(nullptr) {} CompressedInputStream::CompressedInputStream(const std::string& name) - : std::istream(0), file_buf_(nullptr), buf_(nullptr) { + : std::istream(0), + file_buf_(nullptr), + buf_(nullptr) { open(name); } @@ -67,10 +71,14 @@ void CompressedInputStream::close() { // *************************************************************************** CompressedOutputStream::CompressedOutputStream() - : std::ostream(0), file_buf_(nullptr), buf_(nullptr) {} + : std::ostream(0), + file_buf_(nullptr), + buf_(nullptr) {} CompressedOutputStream::CompressedOutputStream(const std::string& name) - : std::ostream(0), file_buf_(nullptr), buf_(nullptr) { + : std::ostream(0), + file_buf_(nullptr), + buf_(nullptr) { open(name); } diff --git a/src/Core/Configuration.cc b/src/Core/Configuration.cc index ac8346d1f..e034b2a7c 100644 --- a/src/Core/Configuration.cc +++ b/src/Core/Configuration.cc @@ -105,7 +105,7 @@ class Configuration::Resource { /** * Determine if the resource matches a configurtion path. - * @param components the coponents of the configuration path + * @param components the components of the configuration path * @return the number of path components matched by the resource, * or -1 of the resource does not match. */ @@ -168,7 +168,6 @@ void Configuration::Resource::writeUsage(XmlWriter& os) const { /** * Central storage place for all resources. 
*/ - class Configuration::ResourceDataBase : public ReferenceCounted { private: std::set resources; @@ -750,8 +749,10 @@ std::vector Configuration::setFromCommandline( const SourceDescriptor* source) { std::string option; std::vector unparsed; - enum { Option, - Argument } state = Option; + enum { + Option, + Argument + } state = Option; if (arguments.size() == 0) return unparsed; diff --git a/src/Core/Debug.cc b/src/Core/Debug.cc index 011b646f8..350f1a565 100644 --- a/src/Core/Debug.cc +++ b/src/Core/Debug.cc @@ -192,7 +192,7 @@ void stackTrace(std::ostream& os, int cutoff) { #else os << "Creating stack trace (innermost first):" << std::endl; - static const size_t maxTraces = 100; + static const size_t maxTraces = 256; void* array[maxTraces]; size_t nTraces = backtrace(array, maxTraces); char** strings = backtrace_symbols(array, nTraces); diff --git a/src/Core/Delegation.hh b/src/Core/Delegation.hh index ef816b0b1..bd0c325a2 100644 --- a/src/Core/Delegation.hh +++ b/src/Core/Delegation.hh @@ -50,7 +50,8 @@ public: TargetMethod tm; Forward(TargetClass& tc, TargetMethod tm) - : tc(tc), tm(tm) {} + : tc(tc), + tm(tm) {} inline void operator()(T t) { (tc.*tm)(t); } diff --git a/src/Core/Dependency.cc b/src/Core/Dependency.cc index 1858e4021..f64d2720e 100644 --- a/src/Core/Dependency.cc +++ b/src/Core/Dependency.cc @@ -64,7 +64,8 @@ class Parser : public XmlSchemaParser { public: Parser(const Configuration& c, DependencySet& dependencySet) - : XmlSchemaParser(c), dependencySet_(dependencySet) { + : XmlSchemaParser(c), + dependencySet_(dependencySet) { XmlRegularElement* dependency = new XmlRegularElementRelay("dependency", this, XmlRegularElementRelay::startHandler(&Self::startDependency), XmlRegularElementRelay::endHandler(&Self::endDependency)); diff --git a/src/Core/Directory.cc b/src/Core/Directory.cc index 5c291ec70..b4e83eca5 100644 --- a/src/Core/Directory.cc +++ b/src/Core/Directory.cc @@ -249,7 +249,8 @@ struct dirent* DirectoryFileIterator::nextEntry() { } DirectoryFileIterator::DirectoryFileIterator(const std::string& path, const Filter* filter) - : end_(false), filter_(filter) { + : end_(false), + filter_(filter) { uid_ = getuid(); gid_ = getgid(); base_ = path; diff --git a/src/Core/Directory.hh b/src/Core/Directory.hh index 36d380815..332380cc1 100644 --- a/src/Core/Directory.hh +++ b/src/Core/Directory.hh @@ -156,7 +156,8 @@ public: public: FileNameFilter(const std::string& name, bool exactMatch = false) - : name_(name), exactMatch_(exactMatch) {} + : name_(name), + exactMatch_(exactMatch) {} virtual bool operator()(const std::string& path, const struct stat64& state, uid_t uid, gid_t gid) const; }; diff --git a/src/Core/DirectoryArchive.hh b/src/Core/DirectoryArchive.hh index 6b71ff117..721e06f65 100644 --- a/src/Core/DirectoryArchive.hh +++ b/src/Core/DirectoryArchive.hh @@ -33,7 +33,8 @@ private: public: _const_iterator(const DirectoryArchive& a) - : iter_(a.path(), &DirectoryFileIterator::fileFilter), a_(a) { + : iter_(a.path(), &DirectoryFileIterator::fileFilter), + a_(a) { if (iter_) if (a_.probe(iter_.path(), iter_.state(), sizes_)) name_ = iter_.path(); diff --git a/src/Core/Extensions.hh b/src/Core/Extensions.hh index 80b995dd3..e5caa0d72 100644 --- a/src/Core/Extensions.hh +++ b/src/Core/Extensions.hh @@ -36,7 +36,7 @@ using __gnu_cxx::select2nd; #else template -struct identity : public std::unary_function { +struct identity { const T& operator()(const T& v) const { return v; } @@ -46,8 +46,7 @@ struct identity : public std::unary_function { }; template -struct 
binary_compose - : public std::unary_function { +struct binary_compose { typename BinaryFun::result_type operator()(const typename UnaryFun1::argument_type& x) const { return f(g(x), h(x)); } @@ -58,16 +57,14 @@ struct binary_compose }; template -struct select1st - : public std::unary_function { +struct select1st { const typename pair_type::first_type& operator()(const pair_type& v) const { return v.first; } }; template -struct select2nd - : public std::unary_function { +struct select2nd { const typename pair_type::second_type& operator()(const pair_type& v) const { return v.second; } diff --git a/src/Core/FileArchive.cc b/src/Core/FileArchive.cc index 04277fcda..b5b68535a 100644 --- a/src/Core/FileArchive.cc +++ b/src/Core/FileArchive.cc @@ -94,7 +94,9 @@ struct FileArchive::FileInfo { Sizes sizes; FileInfo(const std::string& n, u64 p, const Sizes& s) - : name(n), position(p), sizes(s) {} + : name(n), + position(p), + sizes(s) {} void clear() { name.clear(); @@ -105,7 +107,7 @@ struct FileArchive::FileInfo { return (name.empty() && position == 0); } - struct Empty : public std::unary_function { + struct Empty { bool operator()(const FileInfo& f) const { return f.isEmpty(); } @@ -120,7 +122,8 @@ class FileArchive::_const_iterator : public Archive::_const_iterator { public: _const_iterator(const FileArchive& a) - : iter_(a.files_.begin()), a_(a) { + : iter_(a.files_.begin()), + a_(a) { if (iter_ != a_.files_.end()) { name_ = iter_->name; sizes_ = iter_->sizes; diff --git a/src/Core/Hash.hh b/src/Core/Hash.hh index e84ffab62..63d6c7768 100644 --- a/src/Core/Hash.hh +++ b/src/Core/Hash.hh @@ -77,7 +77,7 @@ struct StringHash { } }; -struct StringEquality : std::binary_function { +struct StringEquality { bool operator()(const char* s, const char* t) const { return (s == t) || (std::strcmp(s, t) == 0); } diff --git a/src/Core/MapParser.hh b/src/Core/MapParser.hh index 21ba5d26a..92d75f579 100644 --- a/src/Core/MapParser.hh +++ b/src/Core/MapParser.hh @@ -103,7 +103,8 @@ class XmlMapDocument : public XmlSchemaParser { private: XmlMapElement* mapElement_; Map& map_; - Map* pseudoCreateMap(const XmlAttributes atts) { + + Map* pseudoCreateMap(const XmlAttributes atts) { return &map_; } diff --git a/src/Core/MappedArchive.cc b/src/Core/MappedArchive.cc index 9445b8448..c88413a6f 100644 --- a/src/Core/MappedArchive.cc +++ b/src/Core/MappedArchive.cc @@ -38,7 +38,9 @@ * */ std::string hostAndProcessId() { - enum { Len = 1000 }; + enum { + Len = 1000 + }; char buf[Len]; gethostname(buf, Len); u64 pid = getpid(); diff --git a/src/Core/MappedArchive.hh b/src/Core/MappedArchive.hh index 3dd58843f..8f990b5ce 100644 --- a/src/Core/MappedArchive.hh +++ b/src/Core/MappedArchive.hh @@ -102,7 +102,8 @@ private: struct MappedItem { MappedItem() - : data(0), size(0) { + : data(0), + size(0) { } std::string name; @@ -153,10 +154,12 @@ template class ConstantVector { public: ConstantVector() - : data_(0), size_(0) {} + : data_(0), + size_(0) {} ConstantVector(const T* mapped, size_t size) - : data_(mapped), size_(size) {} + : data_(mapped), + size_(size) {} ConstantVector(const std::vector& data) : editable_(data), diff --git a/src/Core/MatrixParser.hh b/src/Core/MatrixParser.hh index ff079c2cb..78ce21d6e 100644 --- a/src/Core/MatrixParser.hh +++ b/src/Core/MatrixParser.hh @@ -35,11 +35,8 @@ class XmlMatrixElement : public XmlBuilderElement, XmlElement, CreateByContext> { public: - typedef XmlBuilderElement, - XmlElement, - CreateByContext> - Predecessor; - typedef XmlMatrixElement Self; + typedef XmlBuilderElement, 
XmlElement, CreateByContext> Predecessor; + typedef XmlMatrixElement Self; typedef Math::Matrix* (XmlContext::*CreationHandler)(const XmlAttributes atts); private: @@ -124,7 +121,8 @@ class XmlMatrixDocument : public XmlSchemaParser { private: XmlMatrixElement* matrixElement_; Math::Matrix& matrix_; - Math::Matrix* pseudoCreateMatrix(const XmlAttributes atts) { + + Math::Matrix* pseudoCreateMatrix(const XmlAttributes atts) { return &matrix_; } diff --git a/src/Core/MemoryInfo.cc b/src/Core/MemoryInfo.cc index 3a4f5af54..8da0fc165 100644 --- a/src/Core/MemoryInfo.cc +++ b/src/Core/MemoryInfo.cc @@ -22,7 +22,13 @@ using namespace Core; MemoryInfo::MemoryInfo() - : size_(0), rss_(0), share_(0), text_(0), lib_(0), data_(0), pageSize_(0) { + : size_(0), + rss_(0), + share_(0), + text_(0), + lib_(0), + data_(0), + pageSize_(0) { update(); } diff --git a/src/Core/ObjectCache.hh b/src/Core/ObjectCache.hh index 7df0f390d..4c6f33ef2 100644 --- a/src/Core/ObjectCache.hh +++ b/src/Core/ObjectCache.hh @@ -46,7 +46,8 @@ private: public: ObjectCacheItem(Data* data) - : data_(data), dirty_(true) { + : data_(data), + dirty_(true) { ensure(data); } ~ObjectCacheItem() { diff --git a/src/Core/Parameter.cc b/src/Core/Parameter.cc index 9f5cdee6c..fb716d06a 100644 --- a/src/Core/Parameter.cc +++ b/src/Core/Parameter.cc @@ -42,7 +42,8 @@ AbstractParameter::AbstractParameter(const char* _ident, const char* _short_desc, const char* _long_desc, bool needs_arg) - : ident(_ident), g_param_max_ident(0) { + : ident(_ident), + g_param_max_ident(0) { require(!ident.empty()); require(ident.find(Configuration::resource_wildcard_char) == std::string::npos); require(ident.find(Configuration::resource_separation_char) == std::string::npos); @@ -339,10 +340,11 @@ VectorParameter::VectorParameter(const char* ident, maxValue_(maxValue), minSize_(minSize), maxSize_(maxSize) { - if (delimiter_ == "") + if (delimiter_ == "") { delimiter_ = " "; - require(minValue_ <= maxValue_); - require(minSize_ <= maxSize_); + } + require_le(minValue_, maxValue_); + require_le(minSize_, maxSize_); } template diff --git a/src/Core/PriorityQueue.hh b/src/Core/PriorityQueue.hh index c66b8c40e..181fb1a58 100644 --- a/src/Core/PriorityQueue.hh +++ b/src/Core/PriorityQueue.hh @@ -42,7 +42,8 @@ public: PriorityQueueBase(u32 maxSize = Type::max) : maxSize_(maxSize) {} PriorityQueueBase(const PriorityFunction& precedes, u32 maxSize = Type::max) - : maxSize_(maxSize), precedes_(precedes) {} + : maxSize_(maxSize), + precedes_(precedes) {} /** Return reference to top-most item in the queue */ const Item& top() const { @@ -127,7 +128,8 @@ public: protected: std::vector heap_; - bool invariant() const { + + bool invariant() const { return true; } diff --git a/src/Core/ProgressIndicator.cc b/src/Core/ProgressIndicator.cc index c4e757cc0..9927bbb80 100644 --- a/src/Core/ProgressIndicator.cc +++ b/src/Core/ProgressIndicator.cc @@ -220,7 +220,13 @@ void ProgressIndicator::sigWinchHandler(int sig) { ProgressIndicator::ProgressIndicator(const std::string& task, const std::string& unit) - : align_(Left), task_(task), unit_(unit), done_(0), isVisible_(false), draw(0), write_return_val_(0) {} + : align_(Left), + task_(task), + unit_(unit), + done_(0), + isVisible_(false), + draw(0), + write_return_val_(0) {} ProgressIndicator::~ProgressIndicator() { if (activeInstance == this) diff --git a/src/Core/ProgressIndicator.hh b/src/Core/ProgressIndicator.hh index 10893a910..9017a80ef 100644 --- a/src/Core/ProgressIndicator.hh +++ b/src/Core/ProgressIndicator.hh @@ -64,8 
+64,10 @@ namespace Core { class ProgressIndicator { public: - enum Alignment { Left, - Right }; + enum Alignment { + Left, + Right + }; private: static const int defaultLength; diff --git a/src/Core/ReferenceCounting.cc b/src/Core/ReferenceCounting.cc index 1192e835c..f5adfef93 100644 --- a/src/Core/ReferenceCounting.cc +++ b/src/Core/ReferenceCounting.cc @@ -13,6 +13,9 @@ * limitations under the License. */ #include "ReferenceCounting.hh" + +#include + #include "Hash.hh" #include "Utility.hh" diff --git a/src/Core/ReferenceCounting.hh b/src/Core/ReferenceCounting.hh index ac423b05b..ac7127f6a 100644 --- a/src/Core/ReferenceCounting.hh +++ b/src/Core/ReferenceCounting.hh @@ -67,9 +67,11 @@ protected: public: ReferenceCounted() - : referenceCount_(0), weak_refs_(nullptr) {} + : referenceCount_(0), + weak_refs_(nullptr) {} ReferenceCounted(const ReferenceCounted&) - : referenceCount_(0), weak_refs_(nullptr) {} + : referenceCount_(0), + weak_refs_(nullptr) {} ReferenceCounted& operator=(const ReferenceCounted&) { return *this; } diff --git a/src/Core/ResourceUsageInfo.cc b/src/Core/ResourceUsageInfo.cc index 81433b86d..0813fd4c1 100644 --- a/src/Core/ResourceUsageInfo.cc +++ b/src/Core/ResourceUsageInfo.cc @@ -20,7 +20,15 @@ using namespace Core; ResourceUsageInfo::ResourceUsageInfo() - : maxrss_(0), minflt_(0), majflt_(0), inblock_(0), outblock_(0), nvcsw_(0), nicsw_(0), utime_(0), stime_(0) { + : maxrss_(0), + minflt_(0), + majflt_(0), + inblock_(0), + outblock_(0), + nvcsw_(0), + nicsw_(0), + utime_(0), + stime_(0) { update(); } diff --git a/src/Core/StringExpression.hh b/src/Core/StringExpression.hh index 55ccd16ae..4ff26074f 100644 --- a/src/Core/StringExpression.hh +++ b/src/Core/StringExpression.hh @@ -44,7 +44,8 @@ private: Token() : set_(false) {} Token(const std::string& value) - : set_(true), value_(value) {} + : set_(true), + value_(value) {} void operator=(const std::string& v) { set_ = true; @@ -134,7 +135,9 @@ public: StringExpressionParser(StringExpression& toBuild, const std::string& openTag = "$(", const std::string& closeTag = ")") - : toBuild_(toBuild), openTag_(openTag), closeTag_(closeTag) {} + : toBuild_(toBuild), + openTag_(openTag), + closeTag_(closeTag) {} bool accept(const std::string& stringExpression); }; diff --git a/src/Core/TextStream.cc b/src/Core/TextStream.cc index c3d32ca57..404dc19f8 100644 --- a/src/Core/TextStream.cc +++ b/src/Core/TextStream.cc @@ -42,9 +42,11 @@ class TextOutputStream::Buffer : public std::streambuf { void chooseFastPath(); std::string pre_; /**< unformatted */ std::string pend_; /**< formatted */ - enum { startOfLine, - leftMarginOfLine, - withinLine } lineState_; + enum { + startOfLine, + leftMarginOfLine, + withinLine + } lineState_; u32 protection_; /**< number of times "protect" has been seen */ u32 position_, length_; void startNewLine(); @@ -301,13 +303,15 @@ std::streamsize TextOutputStream::Buffer::xsputn(const char* s, std::streamsize // =========================================================================== TextOutputStream::TextOutputStream() - : std::ostream(new Buffer(this)), output_(0) { + : std::ostream(new Buffer(this)), + output_(0) { buffer_ = dynamic_cast(rdbuf()); setEncoding(); } TextOutputStream::TextOutputStream(std::ostream* s) - : std::ostream(new Buffer(this)), output_(s) { + : std::ostream(new Buffer(this)), + output_(s) { buffer_ = dynamic_cast(rdbuf()); buffer_->setOutput(output_->rdbuf()); setEncoding(); @@ -475,13 +479,15 @@ int TextInputStream::Buffer::underflow() { // 
=========================================================================== TextInputStream::TextInputStream() - : std::istream(new Buffer(this)), input_(0) { + : std::istream(new Buffer(this)), + input_(0) { buffer_ = dynamic_cast(rdbuf()); setEncoding(); } TextInputStream::TextInputStream(std::istream* s) - : std::istream(new Buffer(this)), input_(s) { + : std::istream(new Buffer(this)), + input_(s) { buffer_ = dynamic_cast(rdbuf()); buffer_->setInput(input_->rdbuf()); setEncoding(); diff --git a/src/Core/Thread.hh b/src/Core/Thread.hh index 0d5146458..434727122 100644 --- a/src/Core/Thread.hh +++ b/src/Core/Thread.hh @@ -194,7 +194,8 @@ template class LockingPointer { public: LockingPointer(volatile T& obj, const volatile Mutex& mutex) - : obj_(const_cast(&obj)), mutex_(const_cast(&mutex)) { + : obj_(const_cast(&obj)), + mutex_(const_cast(&mutex)) { mutex_->lock(); } ~LockingPointer() { diff --git a/src/Core/ThreadPool.hh b/src/Core/ThreadPool.hh index e29d60d8d..73a32f6f9 100644 --- a/src/Core/ThreadPool.hh +++ b/src/Core/ThreadPool.hh @@ -60,7 +60,8 @@ public: typedef typename Pool::Mapper Mapper; ThreadPoolThread(Pool* pool, Mapper* mapper) - : pool_(pool), mapper_(mapper) {} + : pool_(pool), + mapper_(mapper) {} virtual ~ThreadPoolThread() {} void run() { @@ -92,7 +93,9 @@ public: typedef ThreadPoolThread WorkerThread; ThreadPoolImpl() - : active_threads_(0), running_threads_(0), terminate_(false) {} + : active_threads_(0), + running_threads_(0), + terminate_(false) {} ~ThreadPoolImpl() { for (typename std::vector::iterator t = threads_.begin(); t != threads_.end(); ++t) { delete (*t)->getMapper(); diff --git a/src/Core/ThreadSafeReference.hh b/src/Core/ThreadSafeReference.hh index a1e52c68f..192d963ad 100644 --- a/src/Core/ThreadSafeReference.hh +++ b/src/Core/ThreadSafeReference.hh @@ -38,10 +38,6 @@ private: mutable Core::Mutex mutex_; private: - static inline ThreadSafeReferenceCounted* sentinel() { - static ThreadSafeReferenceCounted sentinel_(1); - return &sentinel_; - } static bool isSentinel(const ThreadSafeReferenceCounted* object) { return object == sentinel(); } @@ -50,6 +46,10 @@ private: } protected: + static inline ThreadSafeReferenceCounted* sentinel() { + static ThreadSafeReferenceCounted sentinel_(1); + return &sentinel_; + } virtual void free() const { verify_(isNotSentinel(this)); delete this; @@ -118,7 +118,7 @@ protected: verify_(Object::isNotSentinel(object_)); object_->release(); object_->free(); - (object_ = sentinel())->increment(); + object_ = sentinel(); } else object_->release(); diff --git a/src/Core/Tokenizer.cc b/src/Core/Tokenizer.cc index 1e73e7dda..1b69759e3 100644 --- a/src/Core/Tokenizer.cc +++ b/src/Core/Tokenizer.cc @@ -20,15 +20,20 @@ using namespace Core; const std::string StringTokenizer::whiteSpace_ = " \t\n\r\f\v"; StringTokenizer::Iterator::Iterator() - : parent_(), begin_(0), end_(0) {} + : parent_(), + begin_(0), + end_(0) {} StringTokenizer::Iterator::Iterator(const StringTokenizer* parent) - : parent_(parent), begin_(0) { + : parent_(parent), + begin_(0) { end_ = findNext(begin_); } StringTokenizer::Iterator::Iterator(const StringTokenizer* parent, size_type begin, size_type end) - : parent_(parent), begin_(begin), end_(end) {} + : parent_(parent), + begin_(begin), + end_(end) {} inline StringTokenizer::Iterator::size_type StringTokenizer::Iterator::findStart(size_type begin) const { return parent_->str_.find_first_not_of(parent_->delim_, begin); @@ -65,10 +70,16 @@ std::string StringTokenizer::Iterator::operator*() const { } 
StringTokenizer::StringTokenizer(const std::string& text, const std::string& delimiter, bool trim) - : str_(text), delim_(delimiter), trim_(trim), endIterator_(this, std::string::npos, std::string::npos) {} + : str_(text), + delim_(delimiter), + trim_(trim), + endIterator_(this, std::string::npos, std::string::npos) {} StringTokenizer::StringTokenizer(const std::string& text) - : str_(text), delim_(whiteSpace_), trim_(true), endIterator_(this, std::string::npos, std::string::npos) {} + : str_(text), + delim_(whiteSpace_), + trim_(true), + endIterator_(this, std::string::npos, std::string::npos) {} StringTokenizer::Iterator StringTokenizer::begin() const { return Iterator(this); diff --git a/src/Core/Unicode.hh b/src/Core/Unicode.hh index bd9c79fb9..63f4f1c12 100644 --- a/src/Core/Unicode.hh +++ b/src/Core/Unicode.hh @@ -84,10 +84,12 @@ const char whitespace[] = " \t\n\r\f\v"; * bytes. The two byte values 0xfe and 0xff are illegal, they can * never occur in a UTF-8 string. */ -enum ByteType { singleByte, - multiByteHead, - multiByteTail, - illegal }; +enum ByteType { + singleByte, + multiByteHead, + multiByteTail, + illegal +}; inline ByteType byteType(char u) { if ((u & 0x80) == 0x00) return singleByte; diff --git a/src/Core/Utility.hh b/src/Core/Utility.hh index 68074f040..dc5296fc1 100644 --- a/src/Core/Utility.hh +++ b/src/Core/Utility.hh @@ -72,7 +72,7 @@ namespace Core { /** Generic unary functor for type conversion. */ template -struct conversion : public std::unary_function { +struct conversion { T operator()(S s) const { return T(s); } @@ -87,7 +87,8 @@ template class tied { public: inline tied(A& a, B& b) - : a_(a), b_(b) {} + : a_(a), + b_(b) {} template inline tied& operator=(const std::pair& p) { a_ = p.first; @@ -132,7 +133,7 @@ T abs(const std::complex& v) { /** absoluteValue: functor for absolute value */ template -struct absoluteValue : public std::unary_function { +struct absoluteValue { T operator()(T v) const { return Core::abs(v); } @@ -171,7 +172,7 @@ T maxAbsoluteElement(const std::vector>& v) { /** power: functor for pow function */ template -struct power : public std::binary_function { +struct power { T operator()(T x, T y) const { return pow(x, y); } @@ -179,7 +180,7 @@ struct power : public std::binary_function { /** min: functor for min function */ template -struct minimum : public std::binary_function { +struct minimum { T operator()(T x, T y) const { return std::min(x, y); } @@ -187,7 +188,7 @@ struct minimum : public std::binary_function { /** max: functor for max function */ template -struct maximum : public std::binary_function { +struct maximum { T operator()(T x, T y) const { return std::max(x, y); } @@ -274,10 +275,7 @@ bool isMalformed(InputIterator begin, InputIterator end) { /** Functor for f(g(x), h(y)) */ template -class composedBinaryFunction - : public std::binary_function { +class composedBinaryFunction { protected: F f_; G g_; @@ -285,8 +283,11 @@ protected: public: composedBinaryFunction(const F& f, const G& g, const H& h) - : f_(f), g_(g), h_(h) {} - typename F::result_type operator()(const typename G::argument_type& x, const typename H::argument_type& y) const { + : f_(f), + g_(g), + h_(h) {} + template + auto operator()(const X& x, const Y& y) const -> decltype(f_(g_(x), h_(y))) { return f_(g_(x), h_(y)); } }; diff --git a/src/Core/XTermUtilities.hh b/src/Core/XTermUtilities.hh index 695a27c01..69afe38c7 100644 --- a/src/Core/XTermUtilities.hh +++ b/src/Core/XTermUtilities.hh @@ -68,7 +68,8 @@ struct Command { struct move : public Command { 
u16 row, col; move(u16 row = 0, u16 col = 0) - : row(row), col(col) {} + : row(row), + col(col) {} inline void write(std::ostream& out) const { out << "\033[" << row << ";" << col << "H"; } diff --git a/src/Core/XmlBuilder.hh b/src/Core/XmlBuilder.hh index 81604a102..ec7abd731 100644 --- a/src/Core/XmlBuilder.hh +++ b/src/Core/XmlBuilder.hh @@ -55,7 +55,8 @@ public: } XmlBuilderElementTemplate(const char* _name, XmlContext* _context, Handler _handler = 0) - : Precursor(_name, _context), handler_(_handler) {} + : Precursor(_name, _context), + handler_(_handler) {} }; struct CreateStatic {}; diff --git a/src/Core/check.cc b/src/Core/check.cc index ac916f505..a139b181f 100644 --- a/src/Core/check.cc +++ b/src/Core/check.cc @@ -38,9 +38,11 @@ using namespace Core; class Tester : public Component { public: - enum Flavour { vanilla, - strawberry, - chocolate }; + enum Flavour { + vanilla, + strawberry, + chocolate + }; static const Choice flavourChoice; static const ParameterBool paramBoolean; diff --git a/src/Core/robin_hood.h b/src/Core/robin_hood.h new file mode 100644 index 000000000..0af031f5f --- /dev/null +++ b/src/Core/robin_hood.h @@ -0,0 +1,2544 @@ +// ______ _____ ______ _________ +// ______________ ___ /_ ___(_)_______ ___ /_ ______ ______ ______ / +// __ ___/_ __ \__ __ \__ / __ __ \ __ __ \_ __ \_ __ \_ __ / +// _ / / /_/ /_ /_/ /_ / _ / / / _ / / // /_/ // /_/ // /_/ / +// /_/ \____/ /_.___/ /_/ /_/ /_/ ________/_/ /_/ \____/ \____/ \__,_/ +// _/_____/ +// +// Fast & memory efficient hashtable based on robin hood hashing for C++11/14/17/20 +// https://github.com/martinus/robin-hood-hashing +// +// Licensed under the MIT License . +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2021 Martin Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROBIN_HOOD_H_INCLUDED +#define ROBIN_HOOD_H_INCLUDED + +// see https://semver.org/ +#define ROBIN_HOOD_VERSION_MAJOR 3 // for incompatible API changes +#define ROBIN_HOOD_VERSION_MINOR 11 // for adding functionality in a backwards-compatible manner +#define ROBIN_HOOD_VERSION_PATCH 5 // for backwards-compatible bug fixes + +#include +#include +#include +#include +#include +#include // only to support hash of smart pointers +#include +#include +#include +#include +#if __cplusplus >= 201703L +# include +#endif + +// #define ROBIN_HOOD_LOG_ENABLED +#ifdef ROBIN_HOOD_LOG_ENABLED +# include +# define ROBIN_HOOD_LOG(...) 
\ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << __VA_ARGS__ << std::endl; +#else +# define ROBIN_HOOD_LOG(x) +#endif + +// #define ROBIN_HOOD_TRACE_ENABLED +#ifdef ROBIN_HOOD_TRACE_ENABLED +# include +# define ROBIN_HOOD_TRACE(...) \ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << __VA_ARGS__ << std::endl; +#else +# define ROBIN_HOOD_TRACE(x) +#endif + +// #define ROBIN_HOOD_COUNT_ENABLED +#ifdef ROBIN_HOOD_COUNT_ENABLED +# include +# define ROBIN_HOOD_COUNT(x) ++counts().x; +namespace robin_hood { +struct Counts { + uint64_t shiftUp{}; + uint64_t shiftDown{}; +}; +inline std::ostream& operator<<(std::ostream& os, Counts const& c) { + return os << c.shiftUp << " shiftUp" << std::endl << c.shiftDown << " shiftDown" << std::endl; +} + +static Counts& counts() { + static Counts counts{}; + return counts; +} +} // namespace robin_hood +#else +# define ROBIN_HOOD_COUNT(x) +#endif + +// all non-argument macros should use this facility. See +// https://www.fluentcpp.com/2019/05/28/better-macros-better-flags/ +#define ROBIN_HOOD(x) ROBIN_HOOD_PRIVATE_DEFINITION_##x() + +// mark unused members with this macro +#define ROBIN_HOOD_UNUSED(identifier) + +// bitness +#if SIZE_MAX == UINT32_MAX +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITNESS() 32 +#elif SIZE_MAX == UINT64_MAX +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITNESS() 64 +#else +# error Unsupported bitness +#endif + +// endianess +#ifdef _MSC_VER +# define ROBIN_HOOD_PRIVATE_DEFINITION_LITTLE_ENDIAN() 1 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BIG_ENDIAN() 0 +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_LITTLE_ENDIAN() \ + (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +# define ROBIN_HOOD_PRIVATE_DEFINITION_BIG_ENDIAN() (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +#endif + +// inline +#ifdef _MSC_VER +# define ROBIN_HOOD_PRIVATE_DEFINITION_NOINLINE() __declspec(noinline) +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_NOINLINE() __attribute__((noinline)) +#endif + +// exceptions +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_EXCEPTIONS() 0 +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_EXCEPTIONS() 1 +#endif + +// count leading/trailing bits +#if !defined(ROBIN_HOOD_DISABLE_INTRINSICS) +# ifdef _MSC_VER +# if ROBIN_HOOD(BITNESS) == 32 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITSCANFORWARD() _BitScanForward +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BITSCANFORWARD() _BitScanForward64 +# endif +# include +# pragma intrinsic(ROBIN_HOOD(BITSCANFORWARD)) +# define ROBIN_HOOD_COUNT_TRAILING_ZEROES(x) \ + [](size_t mask) noexcept -> int { \ + unsigned long index; \ + return ROBIN_HOOD(BITSCANFORWARD)(&index, mask) ? static_cast(index) \ + : ROBIN_HOOD(BITNESS); \ + }(x) +# else +# if ROBIN_HOOD(BITNESS) == 32 +# define ROBIN_HOOD_PRIVATE_DEFINITION_CTZ() __builtin_ctzl +# define ROBIN_HOOD_PRIVATE_DEFINITION_CLZ() __builtin_clzl +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_CTZ() __builtin_ctzll +# define ROBIN_HOOD_PRIVATE_DEFINITION_CLZ() __builtin_clzll +# endif +# define ROBIN_HOOD_COUNT_LEADING_ZEROES(x) ((x) ? ROBIN_HOOD(CLZ)(x) : ROBIN_HOOD(BITNESS)) +# define ROBIN_HOOD_COUNT_TRAILING_ZEROES(x) ((x) ? 
ROBIN_HOOD(CTZ)(x) : ROBIN_HOOD(BITNESS)) +# endif +#endif + +// fallthrough +#ifndef __has_cpp_attribute // For backwards compatibility +# define __has_cpp_attribute(x) 0 +#endif +#if __has_cpp_attribute(clang::fallthrough) +# define ROBIN_HOOD_PRIVATE_DEFINITION_FALLTHROUGH() [[clang::fallthrough]] +#elif __has_cpp_attribute(gnu::fallthrough) +# define ROBIN_HOOD_PRIVATE_DEFINITION_FALLTHROUGH() [[gnu::fallthrough]] +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_FALLTHROUGH() +#endif + +// likely/unlikely +#ifdef _MSC_VER +# define ROBIN_HOOD_LIKELY(condition) condition +# define ROBIN_HOOD_UNLIKELY(condition) condition +#else +# define ROBIN_HOOD_LIKELY(condition) __builtin_expect(condition, 1) +# define ROBIN_HOOD_UNLIKELY(condition) __builtin_expect(condition, 0) +#endif + +// detect if native wchar_t type is availiable in MSVC +#ifdef _MSC_VER +# ifdef _NATIVE_WCHAR_T_DEFINED +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 1 +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 0 +# endif +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 1 +#endif + +// detect if MSVC supports the pair(std::piecewise_construct_t,...) consructor being constexpr +#ifdef _MSC_VER +# if _MSC_VER <= 1900 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 1 +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 0 +# endif +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 0 +#endif + +// workaround missing "is_trivially_copyable" in g++ < 5.0 +// See https://stackoverflow.com/a/31798726/48181 +#if defined(__GNUC__) && __GNUC__ < 5 +# define ROBIN_HOOD_IS_TRIVIALLY_COPYABLE(...) __has_trivial_copy(__VA_ARGS__) +#else +# define ROBIN_HOOD_IS_TRIVIALLY_COPYABLE(...) std::is_trivially_copyable<__VA_ARGS__>::value +#endif + +// helpers for C++ versions, see https://gcc.gnu.org/onlinedocs/cpp/Standard-Predefined-Macros.html +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX() __cplusplus +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX98() 199711L +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX11() 201103L +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX14() 201402L +#define ROBIN_HOOD_PRIVATE_DEFINITION_CXX17() 201703L + +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX17) +# define ROBIN_HOOD_PRIVATE_DEFINITION_NODISCARD() [[nodiscard]] +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_NODISCARD() +#endif + +namespace robin_hood { + +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX14) +# define ROBIN_HOOD_STD std +#else + +// c++11 compatibility layer +namespace ROBIN_HOOD_STD { +template +struct alignment_of + : std::integral_constant::type)> {}; + +template +class integer_sequence { +public: + using value_type = T; + static_assert(std::is_integral::value, "not integral type"); + static constexpr std::size_t size() noexcept { + return sizeof...(Ints); + } +}; +template +using index_sequence = integer_sequence; + +namespace detail_ { +template +struct IntSeqImpl { + using TValue = T; + static_assert(std::is_integral::value, "not integral type"); + static_assert(Begin >= 0 && Begin < End, "unexpected argument (Begin<0 || Begin<=End)"); + + template + struct IntSeqCombiner; + + template + struct IntSeqCombiner, integer_sequence> { + using TResult = integer_sequence; + }; + + using TResult = + typename IntSeqCombiner::TResult, + typename IntSeqImpl::TResult>::TResult; +}; + +template +struct IntSeqImpl { + using TValue = T; + static_assert(std::is_integral::value, "not integral type"); + static_assert(Begin >= 0, "unexpected argument (Begin<0)"); + using TResult = integer_sequence; +}; 
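
These IntSeqImpl specializations backport C++14's std::integer_sequence machinery for C++11 builds; the map's pair type further down uses index_sequence_for to expand the tuples of a piecewise construction. A minimal sketch of that idiom, written against the std:: names the shim mirrors rather than the shim itself:

#include <cstdio>
#include <tuple>
#include <utility>

// Call f with the elements of tuple t as separate arguments by expanding an index pack.
template <typename F, typename... Args, std::size_t... I>
void callWithTuple(F f, const std::tuple<Args...>& t, std::index_sequence<I...>) {
    f(std::get<I>(t)...);
}

int main() {
    auto print = [](int a, int b, int c) { std::printf("%d %d %d\n", a, b, c); };
    std::tuple<int, int, int> t{1, 2, 3};
    callWithTuple(print, t, std::index_sequence_for<int, int, int>{});  // prints "1 2 3"
    return 0;
}
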
+ +template +struct IntSeqImpl { + using TValue = T; + static_assert(std::is_integral::value, "not integral type"); + static_assert(Begin >= 0, "unexpected argument (Begin<0)"); + using TResult = integer_sequence; +}; +} // namespace detail_ + +template +using make_integer_sequence = typename detail_::IntSeqImpl::TResult; + +template +using make_index_sequence = make_integer_sequence; + +template +using index_sequence_for = make_index_sequence; + +} // namespace ROBIN_HOOD_STD + +#endif + +namespace detail { + +// make sure we static_cast to the correct type for hash_int +#if ROBIN_HOOD(BITNESS) == 64 +using SizeT = uint64_t; +#else +using SizeT = uint32_t; +#endif + +template +T rotr(T x, unsigned k) { + return (x >> k) | (x << (8U * sizeof(T) - k)); +} + +// This cast gets rid of warnings like "cast from 'uint8_t*' {aka 'unsigned char*'} to +// 'uint64_t*' {aka 'long unsigned int*'} increases required alignment of target type". Use with +// care! +template +inline T reinterpret_cast_no_cast_align_warning(void* ptr) noexcept { + return reinterpret_cast(ptr); +} + +template +inline T reinterpret_cast_no_cast_align_warning(void const* ptr) noexcept { + return reinterpret_cast(ptr); +} + +// make sure this is not inlined as it is slow and dramatically enlarges code, thus making other +// inlinings more difficult. Throws are also generally the slow path. +template +[[noreturn]] ROBIN_HOOD(NOINLINE) +#if ROBIN_HOOD(HAS_EXCEPTIONS) + void doThrow(Args&&... args) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay) + throw E(std::forward(args)...); +} +#else + void doThrow(Args&&... ROBIN_HOOD_UNUSED(args) /*unused*/) { + abort(); +} +#endif + +template +T* assertNotNull(T* t, Args&&... args) { + if (ROBIN_HOOD_UNLIKELY(nullptr == t)) { + doThrow(std::forward(args)...); + } + return t; +} + +template +inline T unaligned_load(void const* ptr) noexcept { + // using memcpy so we don't get into unaligned load problems. + // compiler should optimize this very well anyways. + T t; + std::memcpy(&t, ptr, sizeof(T)); + return t; +} + +// Allocates bulks of memory for objects of type T. This deallocates the memory in the destructor, +// and keeps a linked list of the allocated memory around. Overhead per allocation is the size of a +// pointer. +template +class BulkPoolAllocator { +public: + BulkPoolAllocator() noexcept = default; + + // does not copy anything, just creates a new allocator. + BulkPoolAllocator(const BulkPoolAllocator& ROBIN_HOOD_UNUSED(o) /*unused*/) noexcept + : mHead(nullptr) + , mListForFree(nullptr) {} + + BulkPoolAllocator(BulkPoolAllocator&& o) noexcept + : mHead(o.mHead) + , mListForFree(o.mListForFree) { + o.mListForFree = nullptr; + o.mHead = nullptr; + } + + BulkPoolAllocator& operator=(BulkPoolAllocator&& o) noexcept { + reset(); + mHead = o.mHead; + mListForFree = o.mListForFree; + o.mListForFree = nullptr; + o.mHead = nullptr; + return *this; + } + + BulkPoolAllocator& + // NOLINTNEXTLINE(bugprone-unhandled-self-assignment,cert-oop54-cpp) + operator=(const BulkPoolAllocator& ROBIN_HOOD_UNUSED(o) /*unused*/) noexcept { + // does not do anything + return *this; + } + + ~BulkPoolAllocator() noexcept { + reset(); + } + + // Deallocates all allocated memory. + void reset() noexcept { + while (mListForFree) { + T* tmp = *mListForFree; + ROBIN_HOOD_LOG("std::free") + std::free(mListForFree); + mListForFree = reinterpret_cast_no_cast_align_warning(tmp); + } + mHead = nullptr; + } + + // allocates, but does NOT initialize. Use in-place new constructor, e.g. 
+ // T* obj = pool.allocate(); + // ::new (static_cast(obj)) T(); + T* allocate() { + T* tmp = mHead; + if (!tmp) { + tmp = performAllocation(); + } + + mHead = *reinterpret_cast_no_cast_align_warning(tmp); + return tmp; + } + + // does not actually deallocate but puts it in store. + // make sure you have already called the destructor! e.g. with + // obj->~T(); + // pool.deallocate(obj); + void deallocate(T* obj) noexcept { + *reinterpret_cast_no_cast_align_warning(obj) = mHead; + mHead = obj; + } + + // Adds an already allocated block of memory to the allocator. This allocator is from now on + // responsible for freeing the data (with free()). If the provided data is not large enough to + // make use of, it is immediately freed. Otherwise it is reused and freed in the destructor. + void addOrFree(void* ptr, const size_t numBytes) noexcept { + // calculate number of available elements in ptr + if (numBytes < ALIGNMENT + ALIGNED_SIZE) { + // not enough data for at least one element. Free and return. + ROBIN_HOOD_LOG("std::free") + std::free(ptr); + } else { + ROBIN_HOOD_LOG("add to buffer") + add(ptr, numBytes); + } + } + + void swap(BulkPoolAllocator& other) noexcept { + using std::swap; + swap(mHead, other.mHead); + swap(mListForFree, other.mListForFree); + } + +private: + // iterates the list of allocated memory to calculate how many to alloc next. + // Recalculating this each time saves us a size_t member. + // This ignores the fact that memory blocks might have been added manually with addOrFree. In + // practice, this should not matter much. + ROBIN_HOOD(NODISCARD) size_t calcNumElementsToAlloc() const noexcept { + auto tmp = mListForFree; + size_t numAllocs = MinNumAllocs; + + while (numAllocs * 2 <= MaxNumAllocs && tmp) { + auto x = reinterpret_cast(tmp); + tmp = *x; + numAllocs *= 2; + } + + return numAllocs; + } + + // WARNING: Underflow if numBytes < ALIGNMENT! This is guarded in addOrFree(). + void add(void* ptr, const size_t numBytes) noexcept { + const size_t numElements = (numBytes - ALIGNMENT) / ALIGNED_SIZE; + + auto data = reinterpret_cast(ptr); + + // link free list + auto x = reinterpret_cast(data); + *x = mListForFree; + mListForFree = data; + + // create linked list for newly allocated data + auto* const headT = + reinterpret_cast_no_cast_align_warning(reinterpret_cast(ptr) + ALIGNMENT); + + auto* const head = reinterpret_cast(headT); + + // Visual Studio compiler automatically unrolls this loop, which is pretty cool + for (size_t i = 0; i < numElements; ++i) { + *reinterpret_cast_no_cast_align_warning(head + i * ALIGNED_SIZE) = + head + (i + 1) * ALIGNED_SIZE; + } + + // last one points to 0 + *reinterpret_cast_no_cast_align_warning(head + (numElements - 1) * ALIGNED_SIZE) = + mHead; + mHead = headT; + } + + // Called when no memory is available (mHead == 0). + // Don't inline this slow path. + ROBIN_HOOD(NOINLINE) T* performAllocation() { + size_t const numElementsToAlloc = calcNumElementsToAlloc(); + + // alloc new memory: [prev |T, T, ... 
T] + size_t const bytes = ALIGNMENT + ALIGNED_SIZE * numElementsToAlloc; + ROBIN_HOOD_LOG("std::malloc " << bytes << " = " << ALIGNMENT << " + " << ALIGNED_SIZE + << " * " << numElementsToAlloc) + add(assertNotNull(std::malloc(bytes)), bytes); + return mHead; + } + + // enforce byte alignment of the T's +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX14) + static constexpr size_t ALIGNMENT = + (std::max)(std::alignment_of::value, std::alignment_of::value); +#else + static const size_t ALIGNMENT = + (ROBIN_HOOD_STD::alignment_of::value > ROBIN_HOOD_STD::alignment_of::value) + ? ROBIN_HOOD_STD::alignment_of::value + : +ROBIN_HOOD_STD::alignment_of::value; // the + is for walkarround +#endif + + static constexpr size_t ALIGNED_SIZE = ((sizeof(T) - 1) / ALIGNMENT + 1) * ALIGNMENT; + + static_assert(MinNumAllocs >= 1, "MinNumAllocs"); + static_assert(MaxNumAllocs >= MinNumAllocs, "MaxNumAllocs"); + static_assert(ALIGNED_SIZE >= sizeof(T*), "ALIGNED_SIZE"); + static_assert(0 == (ALIGNED_SIZE % sizeof(T*)), "ALIGNED_SIZE mod"); + static_assert(ALIGNMENT >= sizeof(T*), "ALIGNMENT"); + + T* mHead{nullptr}; + T** mListForFree{nullptr}; +}; + +template +struct NodeAllocator; + +// dummy allocator that does nothing +template +struct NodeAllocator { + + // we are not using the data, so just free it. + void addOrFree(void* ptr, size_t ROBIN_HOOD_UNUSED(numBytes) /*unused*/) noexcept { + ROBIN_HOOD_LOG("std::free") + std::free(ptr); + } +}; + +template +struct NodeAllocator : public BulkPoolAllocator {}; + +// c++14 doesn't have is_nothrow_swappable, and clang++ 6.0.1 doesn't like it either, so I'm making +// my own here. +namespace swappable { +#if ROBIN_HOOD(CXX) < ROBIN_HOOD(CXX17) +using std::swap; +template +struct nothrow { + static const bool value = noexcept(swap(std::declval(), std::declval())); +}; +#else +template +struct nothrow { + static const bool value = std::is_nothrow_swappable::value; +}; +#endif +} // namespace swappable + +} // namespace detail + +struct is_transparent_tag {}; + +// A custom pair implementation is used in the map because std::pair is not is_trivially_copyable, +// which means it would not be allowed to be used in std::memcpy. This struct is copyable, which is +// also tested. +template +struct pair { + using first_type = T1; + using second_type = T2; + + template ::value && + std::is_default_constructible::value>::type> + constexpr pair() noexcept(noexcept(U1()) && noexcept(U2())) + : first() + , second() {} + + // pair constructors are explicit so we don't accidentally call this ctor when we don't have to. + explicit constexpr pair(std::pair const& o) noexcept( + noexcept(T1(std::declval())) && noexcept(T2(std::declval()))) + : first(o.first) + , second(o.second) {} + + // pair constructors are explicit so we don't accidentally call this ctor when we don't have to. 
+ explicit constexpr pair(std::pair&& o) noexcept(noexcept( + T1(std::move(std::declval()))) && noexcept(T2(std::move(std::declval())))) + : first(std::move(o.first)) + , second(std::move(o.second)) {} + + constexpr pair(T1&& a, T2&& b) noexcept(noexcept( + T1(std::move(std::declval()))) && noexcept(T2(std::move(std::declval())))) + : first(std::move(a)) + , second(std::move(b)) {} + + template + constexpr pair(U1&& a, U2&& b) noexcept(noexcept(T1(std::forward( + std::declval()))) && noexcept(T2(std::forward(std::declval())))) + : first(std::forward(a)) + , second(std::forward(b)) {} + + template + // MSVC 2015 produces error "C2476: ‘constexpr’ constructor does not initialize all members" + // if this constructor is constexpr +#if !ROBIN_HOOD(BROKEN_CONSTEXPR) + constexpr +#endif + pair(std::piecewise_construct_t /*unused*/, std::tuple a, + std::tuple + b) noexcept(noexcept(pair(std::declval&>(), + std::declval&>(), + ROBIN_HOOD_STD::index_sequence_for(), + ROBIN_HOOD_STD::index_sequence_for()))) + : pair(a, b, ROBIN_HOOD_STD::index_sequence_for(), + ROBIN_HOOD_STD::index_sequence_for()) { + } + + // constructor called from the std::piecewise_construct_t ctor + template + pair(std::tuple& a, std::tuple& b, ROBIN_HOOD_STD::index_sequence /*unused*/, ROBIN_HOOD_STD::index_sequence /*unused*/) noexcept( + noexcept(T1(std::forward(std::get( + std::declval&>()))...)) && noexcept(T2(std:: + forward(std::get( + std::declval&>()))...))) + : first(std::forward(std::get(a))...) + , second(std::forward(std::get(b))...) { + // make visual studio compiler happy about warning about unused a & b. + // Visual studio's pair implementation disables warning 4100. + (void)a; + (void)b; + } + + void swap(pair& o) noexcept((detail::swappable::nothrow::value) && + (detail::swappable::nothrow::value)) { + using std::swap; + swap(first, o.first); + swap(second, o.second); + } + + T1 first; // NOLINT(misc-non-private-member-variables-in-classes) + T2 second; // NOLINT(misc-non-private-member-variables-in-classes) +}; + +template +inline void swap(pair& a, pair& b) noexcept( + noexcept(std::declval&>().swap(std::declval&>()))) { + a.swap(b); +} + +template +inline constexpr bool operator==(pair const& x, pair const& y) { + return (x.first == y.first) && (x.second == y.second); +} +template +inline constexpr bool operator!=(pair const& x, pair const& y) { + return !(x == y); +} +template +inline constexpr bool operator<(pair const& x, pair const& y) noexcept(noexcept( + std::declval() < std::declval()) && noexcept(std::declval() < + std::declval())) { + return x.first < y.first || (!(y.first < x.first) && x.second < y.second); +} +template +inline constexpr bool operator>(pair const& x, pair const& y) { + return y < x; +} +template +inline constexpr bool operator<=(pair const& x, pair const& y) { + return !(x > y); +} +template +inline constexpr bool operator>=(pair const& x, pair const& y) { + return !(x < y); +} + +inline size_t hash_bytes(void const* ptr, size_t len) noexcept { + static constexpr uint64_t m = UINT64_C(0xc6a4a7935bd1e995); + static constexpr uint64_t seed = UINT64_C(0xe17a1465); + static constexpr unsigned int r = 47; + + auto const* const data64 = static_cast(ptr); + uint64_t h = seed ^ (len * m); + + size_t const n_blocks = len / 8; + for (size_t i = 0; i < n_blocks; ++i) { + auto k = detail::unaligned_load(data64 + i); + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + auto const* const data8 = reinterpret_cast(data64 + n_blocks); + switch (len & 7U) { + case 7: + h ^= 
static_cast(data8[6]) << 48U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 6: + h ^= static_cast(data8[5]) << 40U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 5: + h ^= static_cast(data8[4]) << 32U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 4: + h ^= static_cast(data8[3]) << 24U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 3: + h ^= static_cast(data8[2]) << 16U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 2: + h ^= static_cast(data8[1]) << 8U; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + case 1: + h ^= static_cast(data8[0]); + h *= m; + ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH + default: + break; + } + + h ^= h >> r; + + // not doing the final step here, because this will be done by keyToIdx anyways + // h *= m; + // h ^= h >> r; + return static_cast(h); +} + +inline size_t hash_int(uint64_t x) noexcept { + // tried lots of different hashes, let's stick with murmurhash3. It's simple, fast, well tested, + // and doesn't need any special 128bit operations. + x ^= x >> 33U; + x *= UINT64_C(0xff51afd7ed558ccd); + x ^= x >> 33U; + + // not doing the final step here, because this will be done by keyToIdx anyways + // x *= UINT64_C(0xc4ceb9fe1a85ec53); + // x ^= x >> 33U; + return static_cast(x); +} + +// A thin wrapper around std::hash, performing an additional simple mixing step of the result. +template +struct hash : public std::hash { + size_t operator()(T const& obj) const + noexcept(noexcept(std::declval>().operator()(std::declval()))) { + // call base hash + auto result = std::hash::operator()(obj); + // return mixed of that, to be save against identity has + return hash_int(static_cast(result)); + } +}; + +template +struct hash> { + size_t operator()(std::basic_string const& str) const noexcept { + return hash_bytes(str.data(), sizeof(CharT) * str.size()); + } +}; + +#if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX17) +template +struct hash> { + size_t operator()(std::basic_string_view const& sv) const noexcept { + return hash_bytes(sv.data(), sizeof(CharT) * sv.size()); + } +}; +#endif + +template +struct hash { + size_t operator()(T* ptr) const noexcept { + return hash_int(reinterpret_cast(ptr)); + } +}; + +template +struct hash> { + size_t operator()(std::unique_ptr const& ptr) const noexcept { + return hash_int(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash> { + size_t operator()(std::shared_ptr const& ptr) const noexcept { + return hash_int(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash::value>::type> { + size_t operator()(Enum e) const noexcept { + using Underlying = typename std::underlying_type::type; + return hash{}(static_cast(e)); + } +}; + +#define ROBIN_HOOD_HASH_INT(T) \ + template <> \ + struct hash { \ + size_t operator()(T const& obj) const noexcept { \ + return hash_int(static_cast(obj)); \ + } \ + } + +#if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuseless-cast" +#endif +// see https://en.cppreference.com/w/cpp/utility/hash +ROBIN_HOOD_HASH_INT(bool); +ROBIN_HOOD_HASH_INT(char); +ROBIN_HOOD_HASH_INT(signed char); +ROBIN_HOOD_HASH_INT(unsigned char); +ROBIN_HOOD_HASH_INT(char16_t); +ROBIN_HOOD_HASH_INT(char32_t); +#if ROBIN_HOOD(HAS_NATIVE_WCHART) +ROBIN_HOOD_HASH_INT(wchar_t); +#endif +ROBIN_HOOD_HASH_INT(short); +ROBIN_HOOD_HASH_INT(unsigned short); +ROBIN_HOOD_HASH_INT(int); +ROBIN_HOOD_HASH_INT(unsigned int); +ROBIN_HOOD_HASH_INT(long); +ROBIN_HOOD_HASH_INT(long long); +ROBIN_HOOD_HASH_INT(unsigned long); +ROBIN_HOOD_HASH_INT(unsigned long long); +#if 
defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic pop +#endif +namespace detail { + +template +struct void_type { + using type = void; +}; + +template +struct has_is_transparent : public std::false_type {}; + +template +struct has_is_transparent::type> + : public std::true_type {}; + +// using wrapper classes for hash and key_equal prevents the diamond problem when the same type +// is used. see https://stackoverflow.com/a/28771920/48181 +template +struct WrapHash : public T { + WrapHash() = default; + explicit WrapHash(T const& o) noexcept(noexcept(T(std::declval()))) + : T(o) {} +}; + +template +struct WrapKeyEqual : public T { + WrapKeyEqual() = default; + explicit WrapKeyEqual(T const& o) noexcept(noexcept(T(std::declval()))) + : T(o) {} +}; + +// A highly optimized hashmap implementation, using the Robin Hood algorithm. +// +// In most cases, this map should be usable as a drop-in replacement for std::unordered_map, but +// be about 2x faster in most cases and require much less allocations. +// +// This implementation uses the following memory layout: +// +// [Node, Node, ... Node | info, info, ... infoSentinel ] +// +// * Node: either a DataNode that directly has the std::pair as member, +// or a DataNode with a pointer to std::pair. Which DataNode representation to use +// depends on how fast the swap() operation is. Heuristically, this is automatically choosen +// based on sizeof(). there are always 2^n Nodes. +// +// * info: Each Node in the map has a corresponding info byte, so there are 2^n info bytes. +// Each byte is initialized to 0, meaning the corresponding Node is empty. Set to 1 means the +// corresponding node contains data. Set to 2 means the corresponding Node is filled, but it +// actually belongs to the previous position and was pushed out because that place is already +// taken. +// +// * infoSentinel: Sentinel byte set to 1, so that iterator's ++ can stop at end() without the +// need for a idx variable. +// +// According to STL, order of templates has effect on throughput. That's why I've moved the +// boolean to the front. 
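
The layout comment above is easier to follow with a concrete toy. The sketch below is an illustration only, not code from this patch: it keeps just the bookkeeping described there (info == 0 for an empty bucket, otherwise info == 1 + distance from the home bucket) and shows how that lets a lookup stop probing early; the real table additionally folds a few bits of the hash into each info byte.

#include <array>
#include <cstdint>
#include <cstdio>
#include <utility>

struct ToyRobinHood {
    static constexpr size_t N = 16;                   // power of two, so masking works
    std::array<int, N>     keys{};
    std::array<uint8_t, N> infos{};                   // 0 == empty

    static size_t home(int key) { return static_cast<size_t>(key) & (N - 1); }

    void insert(int key) {
        size_t  idx  = home(key);
        uint8_t info = 1;                             // distance from home + 1
        while (infos[idx] != 0) {
            if (infos[idx] < info) {                  // resident is closer to home: displace it
                std::swap(key, keys[idx]);
                std::swap(info, infos[idx]);
            }
            idx = (idx + 1) & (N - 1);
            ++info;
        }
        keys[idx]  = key;
        infos[idx] = info;
    }

    bool contains(int key) const {
        size_t  idx  = home(key);
        uint8_t info = 1;
        while (info <= infos[idx]) {                  // stop once we are further from home than the resident
            if (keys[idx] == key)
                return true;
            idx = (idx + 1) & (N - 1);
            ++info;
        }
        return false;
    }
};

int main() {
    ToyRobinHood t;
    t.insert(3);
    t.insert(19);                                     // 19 also maps to bucket 3 and ends up displaced by one slot
    std::printf("%d %d %d\n", t.contains(3), t.contains(19), t.contains(4));  // prints "1 1 0"
    return 0;
}
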
+// https://www.reddit.com/r/cpp/comments/ahp6iu/compile_time_binary_size_reductions_and_cs_future/eeguck4/ +template +class Table + : public WrapHash, + public WrapKeyEqual, + detail::NodeAllocator< + typename std::conditional< + std::is_void::value, Key, + robin_hood::pair::type, T>>::type, + 4, 16384, IsFlat> { +public: + static constexpr bool is_flat = IsFlat; + static constexpr bool is_map = !std::is_void::value; + static constexpr bool is_set = !is_map; + static constexpr bool is_transparent = + has_is_transparent::value && has_is_transparent::value; + + using key_type = Key; + using mapped_type = T; + using value_type = typename std::conditional< + is_set, Key, + robin_hood::pair::type, T>>::type; + using size_type = size_t; + using hasher = Hash; + using key_equal = KeyEqual; + using Self = Table; + +private: + static_assert(MaxLoadFactor100 > 10 && MaxLoadFactor100 < 100, + "MaxLoadFactor100 needs to be >10 && < 100"); + + using WHash = WrapHash; + using WKeyEqual = WrapKeyEqual; + + // configuration defaults + + // make sure we have 8 elements, needed to quickly rehash mInfo + static constexpr size_t InitialNumElements = sizeof(uint64_t); + static constexpr uint32_t InitialInfoNumBits = 5; + static constexpr uint8_t InitialInfoInc = 1U << InitialInfoNumBits; + static constexpr size_t InfoMask = InitialInfoInc - 1U; + static constexpr uint8_t InitialInfoHashShift = 0; + using DataPool = detail::NodeAllocator; + + // type needs to be wider than uint8_t. + using InfoType = uint32_t; + + // DataNode //////////////////////////////////////////////////////// + + // Primary template for the data node. We have special implementations for small and big + // objects. For large objects it is assumed that swap() is fairly slow, so we allocate these + // on the heap so swap merely swaps a pointer. + template + class DataNode {}; + + // Small: just allocate on the stack. + template + class DataNode final { + public: + template + explicit DataNode(M& ROBIN_HOOD_UNUSED(map) /*unused*/, Args&&... args) noexcept( + noexcept(value_type(std::forward(args)...))) + : mData(std::forward(args)...) 
{} + + DataNode(M& ROBIN_HOOD_UNUSED(map) /*unused*/, DataNode&& n) noexcept( + std::is_nothrow_move_constructible::value) + : mData(std::move(n.mData)) {} + + // doesn't do anything + void destroy(M& ROBIN_HOOD_UNUSED(map) /*unused*/) noexcept {} + void destroyDoNotDeallocate() noexcept {} + + value_type const* operator->() const noexcept { + return &mData; + } + value_type* operator->() noexcept { + return &mData; + } + + const value_type& operator*() const noexcept { + return mData; + } + + value_type& operator*() noexcept { + return mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return mData.first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type + getFirst() const noexcept { + return mData.first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() const noexcept { + return mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() noexcept { + return mData.second; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() const noexcept { + return mData.second; + } + + void swap(DataNode& o) noexcept( + noexcept(std::declval().swap(std::declval()))) { + mData.swap(o.mData); + } + + private: + value_type mData; + }; + + // big object: allocate on heap. + template + class DataNode { + public: + template + explicit DataNode(M& map, Args&&... args) + : mData(map.allocate()) { + ::new (static_cast(mData)) value_type(std::forward(args)...); + } + + DataNode(M& ROBIN_HOOD_UNUSED(map) /*unused*/, DataNode&& n) noexcept + : mData(std::move(n.mData)) {} + + void destroy(M& map) noexcept { + // don't deallocate, just put it into list of datapool. + mData->~value_type(); + map.deallocate(mData); + } + + void destroyDoNotDeallocate() noexcept { + mData->~value_type(); + } + + value_type const* operator->() const noexcept { + return mData; + } + + value_type* operator->() noexcept { + return mData; + } + + const value_type& operator*() const { + return *mData; + } + + value_type& operator*() { + return *mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return mData->first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() noexcept { + return *mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type + getFirst() const noexcept { + return mData->first; + } + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getFirst() const noexcept { + return *mData; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() noexcept { + return mData->second; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::type getSecond() const noexcept { + return mData->second; + } + + void swap(DataNode& o) noexcept { + using std::swap; + swap(mData, o.mData); + } + + private: + value_type* mData; + }; + + using Node = DataNode; + + // helpers for insertKeyPrepareEmptySpot: extract first entry (only const required) + ROBIN_HOOD(NODISCARD) key_type const& getFirstConst(Node const& n) const noexcept { + return n.getFirst(); + } + + // in case we have void mapped_type, we are not using a pair, thus we just route k through. + // No need to disable this because it's just not used if not applicable. 
+ ROBIN_HOOD(NODISCARD) key_type const& getFirstConst(key_type const& k) const noexcept { + return k; + } + + // in case we have non-void mapped_type, we have a standard robin_hood::pair + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::value, key_type const&>::type + getFirstConst(value_type const& vt) const noexcept { + return vt.first; + } + + // Cloner ////////////////////////////////////////////////////////// + + template + struct Cloner; + + // fast path: Just copy data, without allocating anything. + template + struct Cloner { + void operator()(M const& source, M& target) const { + auto const* const src = reinterpret_cast(source.mKeyVals); + auto* tgt = reinterpret_cast(target.mKeyVals); + auto const numElementsWithBuffer = target.calcNumElementsWithBuffer(target.mMask + 1); + std::copy(src, src + target.calcNumBytesTotal(numElementsWithBuffer), tgt); + } + }; + + template + struct Cloner { + void operator()(M const& s, M& t) const { + auto const numElementsWithBuffer = t.calcNumElementsWithBuffer(t.mMask + 1); + std::copy(s.mInfo, s.mInfo + t.calcNumBytesInfo(numElementsWithBuffer), t.mInfo); + + for (size_t i = 0; i < numElementsWithBuffer; ++i) { + if (t.mInfo[i]) { + ::new (static_cast(t.mKeyVals + i)) Node(t, *s.mKeyVals[i]); + } + } + } + }; + + // Destroyer /////////////////////////////////////////////////////// + + template + struct Destroyer {}; + + template + struct Destroyer { + void nodes(M& m) const noexcept { + m.mNumElements = 0; + } + + void nodesDoNotDeallocate(M& m) const noexcept { + m.mNumElements = 0; + } + }; + + template + struct Destroyer { + void nodes(M& m) const noexcept { + m.mNumElements = 0; + // clear also resets mInfo to 0, that's sometimes not necessary. + auto const numElementsWithBuffer = m.calcNumElementsWithBuffer(m.mMask + 1); + + for (size_t idx = 0; idx < numElementsWithBuffer; ++idx) { + if (0 != m.mInfo[idx]) { + Node& n = m.mKeyVals[idx]; + n.destroy(m); + n.~Node(); + } + } + } + + void nodesDoNotDeallocate(M& m) const noexcept { + m.mNumElements = 0; + // clear also resets mInfo to 0, that's sometimes not necessary. + auto const numElementsWithBuffer = m.calcNumElementsWithBuffer(m.mMask + 1); + for (size_t idx = 0; idx < numElementsWithBuffer; ++idx) { + if (0 != m.mInfo[idx]) { + Node& n = m.mKeyVals[idx]; + n.destroyDoNotDeallocate(); + n.~Node(); + } + } + } + }; + + // Iter //////////////////////////////////////////////////////////// + + struct fast_forward_tag {}; + + // generic iterator for both const_iterator and iterator. + template + // NOLINTNEXTLINE(hicpp-special-member-functions,cppcoreguidelines-special-member-functions) + class Iter { + private: + using NodePtr = typename std::conditional::type; + + public: + using difference_type = std::ptrdiff_t; + using value_type = typename Self::value_type; + using reference = typename std::conditional::type; + using pointer = typename std::conditional::type; + using iterator_category = std::forward_iterator_tag; + + // default constructed iterator can be compared to itself, but WON'T return true when + // compared to end(). + Iter() = default; + + // Rule of zero: nothing specified. The conversion constructor is only enabled for + // iterator to const_iterator, so it doesn't accidentally work as a copy ctor. + + // Conversion constructor from iterator to const_iterator. 
+ template ::type> + // NOLINTNEXTLINE(hicpp-explicit-conversions) + Iter(Iter const& other) noexcept + : mKeyVals(other.mKeyVals) + , mInfo(other.mInfo) {} + + Iter(NodePtr valPtr, uint8_t const* infoPtr) noexcept + : mKeyVals(valPtr) + , mInfo(infoPtr) {} + + Iter(NodePtr valPtr, uint8_t const* infoPtr, + fast_forward_tag ROBIN_HOOD_UNUSED(tag) /*unused*/) noexcept + : mKeyVals(valPtr) + , mInfo(infoPtr) { + fastForward(); + } + + template ::type> + Iter& operator=(Iter const& other) noexcept { + mKeyVals = other.mKeyVals; + mInfo = other.mInfo; + return *this; + } + + // prefix increment. Undefined behavior if we are at end()! + Iter& operator++() noexcept { + mInfo++; + mKeyVals++; + fastForward(); + return *this; + } + + Iter operator++(int) noexcept { + Iter tmp = *this; + ++(*this); + return tmp; + } + + reference operator*() const { + return **mKeyVals; + } + + pointer operator->() const { + return &**mKeyVals; + } + + template + bool operator==(Iter const& o) const noexcept { + return mKeyVals == o.mKeyVals; + } + + template + bool operator!=(Iter const& o) const noexcept { + return mKeyVals != o.mKeyVals; + } + + private: + // fast forward to the next non-free info byte + // I've tried a few variants that don't depend on intrinsics, but unfortunately they are + // quite a bit slower than this one. So I've reverted that change again. See map_benchmark. + void fastForward() noexcept { + size_t n = 0; + while (0U == (n = detail::unaligned_load(mInfo))) { + mInfo += sizeof(size_t); + mKeyVals += sizeof(size_t); + } +#if defined(ROBIN_HOOD_DISABLE_INTRINSICS) + // we know for certain that within the next 8 bytes we'll find a non-zero one. + if (ROBIN_HOOD_UNLIKELY(0U == detail::unaligned_load(mInfo))) { + mInfo += 4; + mKeyVals += 4; + } + if (ROBIN_HOOD_UNLIKELY(0U == detail::unaligned_load(mInfo))) { + mInfo += 2; + mKeyVals += 2; + } + if (ROBIN_HOOD_UNLIKELY(0U == *mInfo)) { + mInfo += 1; + mKeyVals += 1; + } +#else +# if ROBIN_HOOD(LITTLE_ENDIAN) + auto inc = ROBIN_HOOD_COUNT_TRAILING_ZEROES(n) / 8; +# else + auto inc = ROBIN_HOOD_COUNT_LEADING_ZEROES(n) / 8; +# endif + mInfo += inc; + mKeyVals += inc; +#endif + } + + friend class Table; + NodePtr mKeyVals{nullptr}; + uint8_t const* mInfo{nullptr}; + }; + + //////////////////////////////////////////////////////////////////// + + // highly performance relevant code. + // Lower bits are used for indexing into the array (2^n size) + // The upper 1-5 bits need to be a reasonable good hash, to save comparisons. + template + void keyToIdx(HashKey&& key, size_t* idx, InfoType* info) const { + // In addition to whatever hash is used, add another mul & shift so we get better hashing. + // This serves as a bad hash prevention, if the given data is + // badly mixed. + auto h = static_cast(WHash::operator()(key)); + + h *= mHashMultiplier; + h ^= h >> 33U; + + // the lower InitialInfoNumBits are reserved for info. + *info = mInfoInc + static_cast((h & InfoMask) >> mInfoHashShift); + *idx = (static_cast(h) >> InitialInfoNumBits) & mMask; + } + + // forwards the index by one, wrapping around at the end + void next(InfoType* info, size_t* idx) const noexcept { + *idx = *idx + 1; + *info += mInfoInc; + } + + void nextWhileLess(InfoType* info, size_t* idx) const noexcept { + // unrolling this by hand did not bring any speedups. + while (*info < mInfo[*idx]) { + next(info, idx); + } + } + + // Shift everything up by one element. Tries to move stuff around. 
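Aside on keyToIdx above: one mixed 64-bit hash is split into a bucket index (upper bits, masked with mMask) and an info byte (lower bits, offset by mInfoInc). The sketch below uses placeholder values for the info-bit constants and omits the dynamic mInfoHashShift adjustment; the real constants are defined earlier in this header:

```cpp
// Hedged sketch of the hash -> (bucket index, info byte) split performed by
// keyToIdx. All constants here are placeholders, not the header's definitions.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t hashMultiplier = UINT64_C(0xc4ceb9fe1a85ec53); // same seed as mHashMultiplier
    const unsigned infoNumBits    = 5;                            // placeholder for InitialInfoNumBits
    const uint64_t infoMask       = (1u << infoNumBits) - 1;      // placeholder for InfoMask
    const uint8_t  infoInc        = 1u << infoNumBits;            // placeholder for InitialInfoInc
    const size_t   mask           = 64 - 1;                       // table with 64 buckets

    uint64_t h = UINT64_C(0x123456789abcdef0);   // pretend this came from Hash::operator()
    h *= hashMultiplier;                          // extra mixing against weak user hashes
    h ^= h >> 33U;

    uint8_t info = static_cast<uint8_t>(infoInc + (h & infoMask)); // low bits -> info byte
    size_t  idx  = (static_cast<size_t>(h) >> infoNumBits) & mask; // next bits -> bucket index
    std::printf("idx=%zu info=%u\n", idx, static_cast<unsigned>(info));
}
```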
+ void + shiftUp(size_t startIdx, + size_t const insertion_idx) noexcept(std::is_nothrow_move_assignable::value) { + auto idx = startIdx; + ::new (static_cast(mKeyVals + idx)) Node(std::move(mKeyVals[idx - 1])); + while (--idx != insertion_idx) { + mKeyVals[idx] = std::move(mKeyVals[idx - 1]); + } + + idx = startIdx; + while (idx != insertion_idx) { + ROBIN_HOOD_COUNT(shiftUp) + mInfo[idx] = static_cast(mInfo[idx - 1] + mInfoInc); + if (ROBIN_HOOD_UNLIKELY(mInfo[idx] + mInfoInc > 0xFF)) { + mMaxNumElementsAllowed = 0; + } + --idx; + } + } + + void shiftDown(size_t idx) noexcept(std::is_nothrow_move_assignable::value) { + // until we find one that is either empty or has zero offset. + // TODO(martinus) we don't need to move everything, just the last one for the same + // bucket. + mKeyVals[idx].destroy(*this); + + // until we find one that is either empty or has zero offset. + while (mInfo[idx + 1] >= 2 * mInfoInc) { + ROBIN_HOOD_COUNT(shiftDown) + mInfo[idx] = static_cast(mInfo[idx + 1] - mInfoInc); + mKeyVals[idx] = std::move(mKeyVals[idx + 1]); + ++idx; + } + + mInfo[idx] = 0; + // don't destroy, we've moved it + // mKeyVals[idx].destroy(*this); + mKeyVals[idx].~Node(); + } + + // copy of find(), except that it returns iterator instead of const_iterator. + template + ROBIN_HOOD(NODISCARD) + size_t findIdx(Other const& key) const { + size_t idx{}; + InfoType info{}; + keyToIdx(key, &idx, &info); + + do { + // unrolling this twice gives a bit of a speedup. More unrolling did not help. + if (info == mInfo[idx] && + ROBIN_HOOD_LIKELY(WKeyEqual::operator()(key, mKeyVals[idx].getFirst()))) { + return idx; + } + next(&info, &idx); + if (info == mInfo[idx] && + ROBIN_HOOD_LIKELY(WKeyEqual::operator()(key, mKeyVals[idx].getFirst()))) { + return idx; + } + next(&info, &idx); + } while (info <= mInfo[idx]); + + // nothing found! + return mMask == 0 ? 0 + : static_cast(std::distance( + mKeyVals, reinterpret_cast_no_cast_align_warning(mInfo))); + } + + void cloneData(const Table& o) { + Cloner()(o, *this); + } + + // inserts a keyval that is guaranteed to be new, e.g. when the hashmap is resized. + // @return True on success, false if something went wrong + void insert_move(Node&& keyval) { + // we don't retry, fail if overflowing + // don't need to check max num elements + if (0 == mMaxNumElementsAllowed && !try_increase_info()) { + throwOverflowError(); + } + + size_t idx{}; + InfoType info{}; + keyToIdx(keyval.getFirst(), &idx, &info); + + // skip forward. Use <= because we are certain that the element is not there. + while (info <= mInfo[idx]) { + idx = idx + 1; + info += mInfoInc; + } + + // key not found, so we are now exactly where we want to insert it. + auto const insertion_idx = idx; + auto const insertion_info = static_cast(info); + if (ROBIN_HOOD_UNLIKELY(insertion_info + mInfoInc > 0xFF)) { + mMaxNumElementsAllowed = 0; + } + + // find an empty spot + while (0 != mInfo[idx]) { + next(&info, &idx); + } + + auto& l = mKeyVals[insertion_idx]; + if (idx == insertion_idx) { + ::new (static_cast(&l)) Node(std::move(keyval)); + } else { + shiftUp(idx, insertion_idx); + l = std::move(keyval); + } + + // put at empty spot + mInfo[insertion_idx] = insertion_info; + + ++mNumElements; + } + +public: + using iterator = Iter; + using const_iterator = Iter; + + Table() noexcept(noexcept(Hash()) && noexcept(KeyEqual())) + : WHash() + , WKeyEqual() { + ROBIN_HOOD_TRACE(this) + } + + // Creates an empty hash map. Nothing is allocated yet, this happens at the first insert. 
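Aside: shiftDown above is backward-shift deletion — after destroying a slot, displaced successors (info >= 2 * mInfoInc) are pulled back one position until a free or home-bucket slot is reached, so no tombstones are needed. A toy standalone reproduction of that loop on plain arrays (illustrative data, not the library types):

```cpp
// Toy sketch of backward-shift deletion using the "info = inc + displacement"
// convention: erase one slot, then pull displaced successors back by one
// position until a slot is empty (info == 0) or already in its home bucket.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const uint8_t inc = 1;                           // mInfoInc analogue
    std::vector<uint8_t> info = {0, 1, 2, 3, 1, 0};  // toy info bytes
    std::vector<int>     val  = {0, 10, 11, 12, 40, 0};

    std::size_t idx = 1;                             // erase the element at index 1
    while (info[idx + 1] >= 2 * inc) {               // successor was displaced
        info[idx] = static_cast<uint8_t>(info[idx + 1] - inc);
        val[idx]  = val[idx + 1];
        ++idx;
    }
    info[idx] = 0;                                   // last moved-from slot is now free

    for (std::size_t i = 0; i < info.size(); ++i)
        std::cout << i << ": info=" << int(info[i]) << " val=" << val[i] << "\n";
}
```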
+ // This tremendously speeds up ctor & dtor of a map that never receives an element. The + // penalty is payed at the first insert, and not before. Lookup of this empty map works + // because everybody points to DummyInfoByte::b. parameter bucket_count is dictated by the + // standard, but we can ignore it. + explicit Table( + size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/, const Hash& h = Hash{}, + const KeyEqual& equal = KeyEqual{}) noexcept(noexcept(Hash(h)) && noexcept(KeyEqual(equal))) + : WHash(h) + , WKeyEqual(equal) { + ROBIN_HOOD_TRACE(this) + } + + template + Table(Iter first, Iter last, size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/ = 0, + const Hash& h = Hash{}, const KeyEqual& equal = KeyEqual{}) + : WHash(h) + , WKeyEqual(equal) { + ROBIN_HOOD_TRACE(this) + insert(first, last); + } + + Table(std::initializer_list initlist, + size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/ = 0, const Hash& h = Hash{}, + const KeyEqual& equal = KeyEqual{}) + : WHash(h) + , WKeyEqual(equal) { + ROBIN_HOOD_TRACE(this) + insert(initlist.begin(), initlist.end()); + } + + Table(Table&& o) noexcept + : WHash(std::move(static_cast(o))) + , WKeyEqual(std::move(static_cast(o))) + , DataPool(std::move(static_cast(o))) { + ROBIN_HOOD_TRACE(this) + if (o.mMask) { + mHashMultiplier = std::move(o.mHashMultiplier); + mKeyVals = std::move(o.mKeyVals); + mInfo = std::move(o.mInfo); + mNumElements = std::move(o.mNumElements); + mMask = std::move(o.mMask); + mMaxNumElementsAllowed = std::move(o.mMaxNumElementsAllowed); + mInfoInc = std::move(o.mInfoInc); + mInfoHashShift = std::move(o.mInfoHashShift); + // set other's mask to 0 so its destructor won't do anything + o.init(); + } + } + + Table& operator=(Table&& o) noexcept { + ROBIN_HOOD_TRACE(this) + if (&o != this) { + if (o.mMask) { + // only move stuff if the other map actually has some data + destroy(); + mHashMultiplier = std::move(o.mHashMultiplier); + mKeyVals = std::move(o.mKeyVals); + mInfo = std::move(o.mInfo); + mNumElements = std::move(o.mNumElements); + mMask = std::move(o.mMask); + mMaxNumElementsAllowed = std::move(o.mMaxNumElementsAllowed); + mInfoInc = std::move(o.mInfoInc); + mInfoHashShift = std::move(o.mInfoHashShift); + WHash::operator=(std::move(static_cast(o))); + WKeyEqual::operator=(std::move(static_cast(o))); + DataPool::operator=(std::move(static_cast(o))); + + o.init(); + + } else { + // nothing in the other map => just clear us. + clear(); + } + } + return *this; + } + + Table(const Table& o) + : WHash(static_cast(o)) + , WKeyEqual(static_cast(o)) + , DataPool(static_cast(o)) { + ROBIN_HOOD_TRACE(this) + if (!o.empty()) { + // not empty: create an exact copy. it is also possible to just iterate through all + // elements and insert them, but copying is probably faster. + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(o.mMask + 1); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + + ROBIN_HOOD_LOG("std::malloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mHashMultiplier = o.mHashMultiplier; + mKeyVals = static_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); + // no need for calloc because clonData does memcpy + mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); + mNumElements = o.mNumElements; + mMask = o.mMask; + mMaxNumElementsAllowed = o.mMaxNumElementsAllowed; + mInfoInc = o.mInfoInc; + mInfoHashShift = o.mInfoHashShift; + cloneData(o); + } + } + + // Creates a copy of the given map. Copy constructor of each entry is used. 
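A hedged usage sketch of the constructors above (range and move construction), assuming the vendored header is reachable as "robin_hood.h"; the move constructor steals the bucket array and leaves the source re-initialized and empty:

```cpp
#include "robin_hood.h"
#include <cassert>
#include <string>
#include <utility>

int main() {
    robin_hood::unordered_map<std::string, int> a;
    a["one"] = 1;
    a["two"] = 2;

    robin_hood::unordered_map<std::string, int> b(a.begin(), a.end()); // range ctor copies entries
    auto c = std::move(a);                                             // move ctor steals the arrays

    assert(c.size() == 2 && b == c);
}
```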
+ // Not sure why clang-tidy thinks this doesn't handle self assignment, it does + // NOLINTNEXTLINE(bugprone-unhandled-self-assignment,cert-oop54-cpp) + Table& operator=(Table const& o) { + ROBIN_HOOD_TRACE(this) + if (&o == this) { + // prevent assigning of itself + return *this; + } + + // we keep using the old allocator and not assign the new one, because we want to keep + // the memory available. when it is the same size. + if (o.empty()) { + if (0 == mMask) { + // nothing to do, we are empty too + return *this; + } + + // not empty: destroy what we have there + // clear also resets mInfo to 0, that's sometimes not necessary. + destroy(); + init(); + WHash::operator=(static_cast(o)); + WKeyEqual::operator=(static_cast(o)); + DataPool::operator=(static_cast(o)); + + return *this; + } + + // clean up old stuff + Destroyer::value>{}.nodes(*this); + + if (mMask != o.mMask) { + // no luck: we don't have the same array size allocated, so we need to realloc. + if (0 != mMask) { + // only deallocate if we actually have data! + ROBIN_HOOD_LOG("std::free") + std::free(mKeyVals); + } + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(o.mMask + 1); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + ROBIN_HOOD_LOG("std::malloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mKeyVals = static_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); + + // no need for calloc here because cloneData performs a memcpy. + mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); + // sentinel is set in cloneData + } + WHash::operator=(static_cast(o)); + WKeyEqual::operator=(static_cast(o)); + DataPool::operator=(static_cast(o)); + mHashMultiplier = o.mHashMultiplier; + mNumElements = o.mNumElements; + mMask = o.mMask; + mMaxNumElementsAllowed = o.mMaxNumElementsAllowed; + mInfoInc = o.mInfoInc; + mInfoHashShift = o.mInfoHashShift; + cloneData(o); + + return *this; + } + + // Swaps everything between the two maps. + void swap(Table& o) { + ROBIN_HOOD_TRACE(this) + using std::swap; + swap(o, *this); + } + + // Clears all data, without resizing. + void clear() { + ROBIN_HOOD_TRACE(this) + if (empty()) { + // don't do anything! also important because we don't want to write to + // DummyInfoByte::b, even though we would just write 0 to it. + return; + } + + Destroyer::value>{}.nodes(*this); + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); + // clear everything, then set the sentinel again + uint8_t const z = 0; + std::fill(mInfo, mInfo + calcNumBytesInfo(numElementsWithBuffer), z); + mInfo[numElementsWithBuffer] = 1; + + mInfoInc = InitialInfoInc; + mInfoHashShift = InitialInfoHashShift; + } + + // Destroys the map and all it's contents. + ~Table() { + ROBIN_HOOD_TRACE(this) + destroy(); + } + + // Checks if both tables contain the same entries. Order is irrelevant. 
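A hedged usage sketch of clear() above, assuming the "robin_hood.h" include path: clear() destroys the elements and zeroes the info bytes but keeps the allocation, so refilling to the same size does not reallocate:

```cpp
#include "robin_hood.h"
#include <cassert>

int main() {
    robin_hood::unordered_map<int, int> m;
    for (int i = 0; i < 1000; ++i) m[i] = i;

    m.clear();                                   // destroys elements, keeps the bucket array
    assert(m.empty());

    for (int i = 0; i < 1000; ++i) m[i] = -i;    // same element count: no reallocation needed
    assert(m.size() == 1000);
}
```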
+ bool operator==(const Table& other) const { + ROBIN_HOOD_TRACE(this) + if (other.size() != size()) { + return false; + } + for (auto const& otherEntry : other) { + if (!has(otherEntry)) { + return false; + } + } + + return true; + } + + bool operator!=(const Table& other) const { + ROBIN_HOOD_TRACE(this) + return !operator==(other); + } + + template + typename std::enable_if::value, Q&>::type operator[](const key_type& key) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) + Node(*this, std::piecewise_construct, std::forward_as_tuple(key), + std::forward_as_tuple()); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(key), std::forward_as_tuple()); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + } + + return mKeyVals[idxAndState.first].getSecond(); + } + + template + typename std::enable_if::value, Q&>::type operator[](key_type&& key) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) + Node(*this, std::piecewise_construct, std::forward_as_tuple(std::move(key)), + std::forward_as_tuple()); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = + Node(*this, std::piecewise_construct, std::forward_as_tuple(std::move(key)), + std::forward_as_tuple()); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + } + + return mKeyVals[idxAndState.first].getSecond(); + } + + template + void insert(Iter first, Iter last) { + for (; first != last; ++first) { + // value_type ctor needed because this might be called with std::pair's + insert(value_type(*first)); + } + } + + void insert(std::initializer_list ilist) { + for (auto&& vt : ilist) { + insert(std::move(vt)); + } + } + + template + std::pair emplace(Args&&... args) { + ROBIN_HOOD_TRACE(this) + Node n{*this, std::forward(args)...}; + auto idxAndState = insertKeyPrepareEmptySpot(getFirstConst(n)); + switch (idxAndState.second) { + case InsertionState::key_found: + n.destroy(*this); + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node(*this, std::move(n)); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = std::move(n); + break; + + case InsertionState::overflow_error: + n.destroy(*this); + throwOverflowError(); + break; + } + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); + } + + template + iterator emplace_hint(const_iterator position, Args&&... args) { + (void)position; + return emplace(std::forward(args)...).first; + } + + template + std::pair try_emplace(const key_type& key, Args&&... args) { + return try_emplace_impl(key, std::forward(args)...); + } + + template + std::pair try_emplace(key_type&& key, Args&&... args) { + return try_emplace_impl(std::move(key), std::forward(args)...); + } + + template + iterator try_emplace(const_iterator hint, const key_type& key, Args&&... 
args) { + (void)hint; + return try_emplace_impl(key, std::forward(args)...).first; + } + + template + iterator try_emplace(const_iterator hint, key_type&& key, Args&&... args) { + (void)hint; + return try_emplace_impl(std::move(key), std::forward(args)...).first; + } + + template + std::pair insert_or_assign(const key_type& key, Mapped&& obj) { + return insertOrAssignImpl(key, std::forward(obj)); + } + + template + std::pair insert_or_assign(key_type&& key, Mapped&& obj) { + return insertOrAssignImpl(std::move(key), std::forward(obj)); + } + + template + iterator insert_or_assign(const_iterator hint, const key_type& key, Mapped&& obj) { + (void)hint; + return insertOrAssignImpl(key, std::forward(obj)).first; + } + + template + iterator insert_or_assign(const_iterator hint, key_type&& key, Mapped&& obj) { + (void)hint; + return insertOrAssignImpl(std::move(key), std::forward(obj)).first; + } + + std::pair insert(const value_type& keyval) { + ROBIN_HOOD_TRACE(this) + return emplace(keyval); + } + + iterator insert(const_iterator hint, const value_type& keyval) { + (void)hint; + return emplace(keyval).first; + } + + std::pair insert(value_type&& keyval) { + return emplace(std::move(keyval)); + } + + iterator insert(const_iterator hint, value_type&& keyval) { + (void)hint; + return emplace(std::move(keyval)).first; + } + + // Returns 1 if key is found, 0 otherwise. + size_t count(const key_type& key) const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv != reinterpret_cast_no_cast_align_warning(mInfo)) { + return 1; + } + return 0; + } + + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::type count(const OtherKey& key) const { + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv != reinterpret_cast_no_cast_align_warning(mInfo)) { + return 1; + } + return 0; + } + + bool contains(const key_type& key) const { // NOLINT(modernize-use-nodiscard) + return 1U == count(key); + } + + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::type contains(const OtherKey& key) const { + return 1U == count(key); + } + + // Returns a reference to the value found for key. + // Throws std::out_of_range if element cannot be found + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::value, Q&>::type at(key_type const& key) { + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv == reinterpret_cast_no_cast_align_warning(mInfo)) { + doThrow("key not found"); + } + return kv->getSecond(); + } + + // Returns a reference to the value found for key. 
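A hedged usage sketch of the insertion API above (operator[], try_emplace, insert_or_assign), assuming the "robin_hood.h" include path; the semantics mirror std::unordered_map:

```cpp
#include "robin_hood.h"
#include <cassert>
#include <string>

int main() {
    robin_hood::unordered_map<std::string, int> m;

    m["a"] = 1;                               // operator[] default-constructs, then assigns
    auto r1 = m.try_emplace("a", 99);         // key exists: value stays 1
    assert(!r1.second && m.at("a") == 1);

    auto r2 = m.insert_or_assign("a", 99);    // key exists: value is overwritten
    assert(!r2.second && m.at("a") == 99);

    assert(m.contains("a") && m.count("b") == 0);
}
```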
+ // Throws std::out_of_range if element cannot be found + template + // NOLINTNEXTLINE(modernize-use-nodiscard) + typename std::enable_if::value, Q const&>::type at(key_type const& key) const { + ROBIN_HOOD_TRACE(this) + auto kv = mKeyVals + findIdx(key); + if (kv == reinterpret_cast_no_cast_align_warning(mInfo)) { + doThrow("key not found"); + } + return kv->getSecond(); + } + + const_iterator find(const key_type& key) const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return const_iterator{mKeyVals + idx, mInfo + idx}; + } + + template + const_iterator find(const OtherKey& key, is_transparent_tag /*unused*/) const { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return const_iterator{mKeyVals + idx, mInfo + idx}; + } + + template + typename std::enable_if::type // NOLINT(modernize-use-nodiscard) + find(const OtherKey& key) const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return const_iterator{mKeyVals + idx, mInfo + idx}; + } + + iterator find(const key_type& key) { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return iterator{mKeyVals + idx, mInfo + idx}; + } + + template + iterator find(const OtherKey& key, is_transparent_tag /*unused*/) { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return iterator{mKeyVals + idx, mInfo + idx}; + } + + template + typename std::enable_if::type find(const OtherKey& key) { + ROBIN_HOOD_TRACE(this) + const size_t idx = findIdx(key); + return iterator{mKeyVals + idx, mInfo + idx}; + } + + iterator begin() { + ROBIN_HOOD_TRACE(this) + if (empty()) { + return end(); + } + return iterator(mKeyVals, mInfo, fast_forward_tag{}); + } + const_iterator begin() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return cbegin(); + } + const_iterator cbegin() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + if (empty()) { + return cend(); + } + return const_iterator(mKeyVals, mInfo, fast_forward_tag{}); + } + + iterator end() { + ROBIN_HOOD_TRACE(this) + // no need to supply valid info pointer: end() must not be dereferenced, and only node + // pointer is compared. + return iterator{reinterpret_cast_no_cast_align_warning(mInfo), nullptr}; + } + const_iterator end() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return cend(); + } + const_iterator cend() const { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return const_iterator{reinterpret_cast_no_cast_align_warning(mInfo), nullptr}; + } + + iterator erase(const_iterator pos) { + ROBIN_HOOD_TRACE(this) + // its safe to perform const cast here + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return erase(iterator{const_cast(pos.mKeyVals), const_cast(pos.mInfo)}); + } + + // Erases element at pos, returns iterator to the next element. + iterator erase(iterator pos) { + ROBIN_HOOD_TRACE(this) + // we assume that pos always points to a valid entry, and not end(). 
+ auto const idx = static_cast(pos.mKeyVals - mKeyVals); + + shiftDown(idx); + --mNumElements; + + if (*pos.mInfo) { + // we've backward shifted, return this again + return pos; + } + + // no backward shift, return next element + return ++pos; + } + + size_t erase(const key_type& key) { + ROBIN_HOOD_TRACE(this) + size_t idx{}; + InfoType info{}; + keyToIdx(key, &idx, &info); + + // check while info matches with the source idx + do { + if (info == mInfo[idx] && WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) { + shiftDown(idx); + --mNumElements; + return 1; + } + next(&info, &idx); + } while (info <= mInfo[idx]); + + // nothing found to delete + return 0; + } + + // reserves space for the specified number of elements. Makes sure the old data fits. + // exactly the same as reserve(c). + void rehash(size_t c) { + // forces a reserve + reserve(c, true); + } + + // reserves space for the specified number of elements. Makes sure the old data fits. + // Exactly the same as rehash(c). Use rehash(0) to shrink to fit. + void reserve(size_t c) { + // reserve, but don't force rehash + reserve(c, false); + } + + // If possible reallocates the map to a smaller one. This frees the underlying table. + // Does not do anything if load_factor is too large for decreasing the table's size. + void compact() { + ROBIN_HOOD_TRACE(this) + auto newSize = InitialNumElements; + while (calcMaxNumElementsAllowed(newSize) < mNumElements && newSize != 0) { + newSize *= 2; + } + if (ROBIN_HOOD_UNLIKELY(newSize == 0)) { + throwOverflowError(); + } + + ROBIN_HOOD_LOG("newSize > mMask + 1: " << newSize << " > " << mMask << " + 1") + + // only actually do anything when the new size is bigger than the old one. This prevents to + // continuously allocate for each reserve() call. + if (newSize < mMask + 1) { + rehashPowerOfTwo(newSize, true); + } + } + + size_type size() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return mNumElements; + } + + size_type max_size() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return static_cast(-1); + } + + ROBIN_HOOD(NODISCARD) bool empty() const noexcept { + ROBIN_HOOD_TRACE(this) + return 0 == mNumElements; + } + + float max_load_factor() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return MaxLoadFactor100 / 100.0F; + } + + // Average number of elements per bucket. Since we allow only 1 per bucket + float load_factor() const noexcept { // NOLINT(modernize-use-nodiscard) + ROBIN_HOOD_TRACE(this) + return static_cast(size()) / static_cast(mMask + 1); + } + + ROBIN_HOOD(NODISCARD) size_t mask() const noexcept { + ROBIN_HOOD_TRACE(this) + return mMask; + } + + ROBIN_HOOD(NODISCARD) size_t calcMaxNumElementsAllowed(size_t maxElements) const noexcept { + if (ROBIN_HOOD_LIKELY(maxElements <= (std::numeric_limits::max)() / 100)) { + return maxElements * MaxLoadFactor100 / 100; + } + + // we might be a bit inprecise, but since maxElements is quite large that doesn't matter + return (maxElements / 100) * MaxLoadFactor100; + } + + ROBIN_HOOD(NODISCARD) size_t calcNumBytesInfo(size_t numElements) const noexcept { + // we add a uint64_t, which houses the sentinel (first byte) and padding so we can load + // 64bit types. 
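A hedged usage sketch of erase() during iteration, assuming the "robin_hood.h" include path. Because a backward shift can move a later element into the erased slot, the loop must continue from the iterator that erase() returns instead of incrementing blindly:

```cpp
#include "robin_hood.h"
#include <cassert>

int main() {
    robin_hood::unordered_map<int, int> m;
    for (int i = 0; i < 100; ++i) m[i] = i;

    for (auto it = m.begin(); it != m.end();) {
        if (it->second % 2 == 0) {
            it = m.erase(it);   // do NOT ++it after erase; use the returned iterator
        } else {
            ++it;
        }
    }
    assert(m.size() == 50);     // all even values removed
}
```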
+ return numElements + sizeof(uint64_t); + } + + ROBIN_HOOD(NODISCARD) + size_t calcNumElementsWithBuffer(size_t numElements) const noexcept { + auto maxNumElementsAllowed = calcMaxNumElementsAllowed(numElements); + return numElements + (std::min)(maxNumElementsAllowed, (static_cast(0xFF))); + } + + // calculation only allowed for 2^n values + ROBIN_HOOD(NODISCARD) size_t calcNumBytesTotal(size_t numElements) const { +#if ROBIN_HOOD(BITNESS) == 64 + return numElements * sizeof(Node) + calcNumBytesInfo(numElements); +#else + // make sure we're doing 64bit operations, so we are at least safe against 32bit overflows. + auto const ne = static_cast(numElements); + auto const s = static_cast(sizeof(Node)); + auto const infos = static_cast(calcNumBytesInfo(numElements)); + + auto const total64 = ne * s + infos; + auto const total = static_cast(total64); + + if (ROBIN_HOOD_UNLIKELY(static_cast(total) != total64)) { + throwOverflowError(); + } + return total; +#endif + } + +private: + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::value, bool>::type has(const value_type& e) const { + ROBIN_HOOD_TRACE(this) + auto it = find(e.first); + return it != end() && it->second == e.second; + } + + template + ROBIN_HOOD(NODISCARD) + typename std::enable_if::value, bool>::type has(const value_type& e) const { + ROBIN_HOOD_TRACE(this) + return find(e) != end(); + } + + void reserve(size_t c, bool forceRehash) { + ROBIN_HOOD_TRACE(this) + auto const minElementsAllowed = (std::max)(c, mNumElements); + auto newSize = InitialNumElements; + while (calcMaxNumElementsAllowed(newSize) < minElementsAllowed && newSize != 0) { + newSize *= 2; + } + if (ROBIN_HOOD_UNLIKELY(newSize == 0)) { + throwOverflowError(); + } + + ROBIN_HOOD_LOG("newSize > mMask + 1: " << newSize << " > " << mMask << " + 1") + + // only actually do anything when the new size is bigger than the old one. This prevents to + // continuously allocate for each reserve() call. + if (forceRehash || newSize > mMask + 1) { + rehashPowerOfTwo(newSize, false); + } + } + + // reserves space for at least the specified number of elements. + // only works if numBuckets if power of two + // True on success, false otherwise + void rehashPowerOfTwo(size_t numBuckets, bool forceFree) { + ROBIN_HOOD_TRACE(this) + + Node* const oldKeyVals = mKeyVals; + uint8_t const* const oldInfo = mInfo; + + const size_t oldMaxElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); + + // resize operation: move stuff + initData(numBuckets); + if (oldMaxElementsWithBuffer > 1) { + for (size_t i = 0; i < oldMaxElementsWithBuffer; ++i) { + if (oldInfo[i] != 0) { + // might throw an exception, which is really bad since we are in the middle of + // moving stuff. + insert_move(std::move(oldKeyVals[i])); + // destroy the node but DON'T destroy the data. + oldKeyVals[i].~Node(); + } + } + + // this check is not necessary as it's guarded by the previous if, but it helps + // silence g++'s overeager "attempt to free a non-heap object 'map' + // [-Werror=free-nonheap-object]" warning. 
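Aside: the size computations above describe a single allocation laid out as [Node array | info bytes | 64-bit sentinel/padding], with an overflow buffer of at most 0xFF extra slots appended to the bucket count. A standalone arithmetic sketch of that total (sizeofNode and the load factor are placeholders, not values taken from a concrete instantiation):

```cpp
// Arithmetic sketch of calcNumElementsWithBuffer / calcNumBytesInfo /
// calcNumBytesTotal for one power-of-two bucket count.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const size_t numBuckets       = 256;   // must be a power of two
    const size_t maxLoadFactor100 = 80;    // MaxLoadFactor100 analogue
    const size_t sizeofNode       = 16;    // placeholder; depends on Key/T/IsFlat

    const size_t maxAllowed = numBuckets * maxLoadFactor100 / 100;
    const size_t withBuffer = numBuckets + std::min<size_t>(maxAllowed, 0xFF);
    const size_t infoBytes  = withBuffer + sizeof(uint64_t);         // + sentinel and padding
    const size_t totalBytes = withBuffer * sizeofNode + infoBytes;

    std::printf("one malloc of %zu bytes for %zu buckets\n", totalBytes, numBuckets);
}
```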
+ if (oldKeyVals != reinterpret_cast_no_cast_align_warning(&mMask)) { + // don't destroy old data: put it into the pool instead + if (forceFree) { + std::free(oldKeyVals); + } else { + DataPool::addOrFree(oldKeyVals, calcNumBytesTotal(oldMaxElementsWithBuffer)); + } + } + } + } + + ROBIN_HOOD(NOINLINE) void throwOverflowError() const { +#if ROBIN_HOOD(HAS_EXCEPTIONS) + throw std::overflow_error("robin_hood::map overflow"); +#else + abort(); +#endif + } + + template + std::pair try_emplace_impl(OtherKey&& key, Args&&... args) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node( + *this, std::piecewise_construct, std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + break; + } + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); + } + + template + std::pair insertOrAssignImpl(OtherKey&& key, Mapped&& obj) { + ROBIN_HOOD_TRACE(this) + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + mKeyVals[idxAndState.first].getSecond() = std::forward(obj); + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node( + *this, std::piecewise_construct, std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(obj))); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(obj))); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + break; + } + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); + } + + void initData(size_t max_elements) { + mNumElements = 0; + mMask = max_elements - 1; + mMaxNumElementsAllowed = calcMaxNumElementsAllowed(max_elements); + + auto const numElementsWithBuffer = calcNumElementsWithBuffer(max_elements); + + // malloc & zero mInfo. Faster than calloc everything. + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + ROBIN_HOOD_LOG("std::calloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mKeyVals = reinterpret_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); + mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); + std::memset(mInfo, 0, numBytesTotal - numElementsWithBuffer * sizeof(Node)); + + // set sentinel + mInfo[numElementsWithBuffer] = 1; + + mInfoInc = InitialInfoInc; + mInfoHashShift = InitialInfoHashShift; + } + + enum class InsertionState { overflow_error, key_found, new_node, overwrite_node }; + + // Finds key, and if not already present prepares a spot where to pot the key & value. + // This potentially shifts nodes out of the way, updates mInfo and number of inserted + // elements, so the only operation left to do is create/assign a new node at that spot. 
+ template + std::pair insertKeyPrepareEmptySpot(OtherKey&& key) { + for (int i = 0; i < 256; ++i) { + size_t idx{}; + InfoType info{}; + keyToIdx(key, &idx, &info); + nextWhileLess(&info, &idx); + + // while we potentially have a match + while (info == mInfo[idx]) { + if (WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) { + // key already exists, do NOT insert. + // see http://en.cppreference.com/w/cpp/container/unordered_map/insert + return std::make_pair(idx, InsertionState::key_found); + } + next(&info, &idx); + } + + // unlikely that this evaluates to true + if (ROBIN_HOOD_UNLIKELY(mNumElements >= mMaxNumElementsAllowed)) { + if (!increase_size()) { + return std::make_pair(size_t(0), InsertionState::overflow_error); + } + continue; + } + + // key not found, so we are now exactly where we want to insert it. + auto const insertion_idx = idx; + auto const insertion_info = info; + if (ROBIN_HOOD_UNLIKELY(insertion_info + mInfoInc > 0xFF)) { + mMaxNumElementsAllowed = 0; + } + + // find an empty spot + while (0 != mInfo[idx]) { + next(&info, &idx); + } + + if (idx != insertion_idx) { + shiftUp(idx, insertion_idx); + } + // put at empty spot + mInfo[insertion_idx] = static_cast(insertion_info); + ++mNumElements; + return std::make_pair(insertion_idx, idx == insertion_idx + ? InsertionState::new_node + : InsertionState::overwrite_node); + } + + // enough attempts failed, so finally give up. + return std::make_pair(size_t(0), InsertionState::overflow_error); + } + + bool try_increase_info() { + ROBIN_HOOD_LOG("mInfoInc=" << mInfoInc << ", numElements=" << mNumElements + << ", maxNumElementsAllowed=" + << calcMaxNumElementsAllowed(mMask + 1)) + if (mInfoInc <= 2) { + // need to be > 2 so that shift works (otherwise undefined behavior!) + return false; + } + // we got space left, try to make info smaller + mInfoInc = static_cast(mInfoInc >> 1U); + + // remove one bit of the hash, leaving more space for the distance info. + // This is extremely fast because we can operate on 8 bytes at once. + ++mInfoHashShift; + auto const numElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); + + for (size_t i = 0; i < numElementsWithBuffer; i += 8) { + auto val = unaligned_load(mInfo + i); + val = (val >> 1U) & UINT64_C(0x7f7f7f7f7f7f7f7f); + std::memcpy(mInfo + i, &val, sizeof(val)); + } + // update sentinel, which might have been cleared out! + mInfo[numElementsWithBuffer] = 1; + + mMaxNumElementsAllowed = calcMaxNumElementsAllowed(mMask + 1); + return true; + } + + // True if resize was possible, false otherwise + bool increase_size() { + // nothing allocated yet? just allocate InitialNumElements + if (0 == mMask) { + initData(InitialNumElements); + return true; + } + + auto const maxNumElementsAllowed = calcMaxNumElementsAllowed(mMask + 1); + if (mNumElements < maxNumElementsAllowed && try_increase_info()) { + return true; + } + + ROBIN_HOOD_LOG("mNumElements=" << mNumElements << ", maxNumElementsAllowed=" + << maxNumElementsAllowed << ", load=" + << (static_cast(mNumElements) * 100.0 / + (static_cast(mMask) + 1))) + + if (mNumElements * 2 < calcMaxNumElementsAllowed(mMask + 1)) { + // we have to resize, even though there would still be plenty of space left! + // Try to rehash instead. Delete freed memory so we don't steadyily increase mem in case + // we have to rehash a few times + nextHashMultiplier(); + rehashPowerOfTwo(mMask + 1, true); + } else { + // we've reached the capacity of the map, so the hash seems to work nice. Keep using it. 
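Aside on try_increase_info above: halving mInfoInc frees one hash bit for distance information, and all stored info bytes are halved eight at a time with a shift-and-mask over a 64-bit load. A standalone demo of that byte-parallel halving:

```cpp
// Demo of the bit trick in try_increase_info: shift a 64-bit word right by one
// and mask off the bit that would otherwise leak in from the neighbouring byte,
// halving eight info bytes in a single operation.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    uint8_t info[8] = {0, 32, 34, 64, 96, 2, 255, 128};

    uint64_t val;
    std::memcpy(&val, info, sizeof(val));              // unaligned_load equivalent
    val = (val >> 1U) & UINT64_C(0x7f7f7f7f7f7f7f7f);  // each byte := byte / 2
    std::memcpy(info, &val, sizeof(val));

    for (int i = 0; i < 8; ++i)
        std::printf("%u ", static_cast<unsigned>(info[i]));  // 0 16 17 32 48 1 127 64
    std::printf("\n");
}
```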
+ rehashPowerOfTwo((mMask + 1) * 2, false); + } + return true; + } + + void nextHashMultiplier() { + // adding an *even* number, so that the multiplier will always stay odd. This is necessary + // so that the hash stays a mixing function (and thus doesn't have any information loss). + mHashMultiplier += UINT64_C(0xc4ceb9fe1a85ec54); + } + + void destroy() { + if (0 == mMask) { + // don't deallocate! + return; + } + + Destroyer::value>{} + .nodesDoNotDeallocate(*this); + + // This protection against not deleting mMask shouldn't be needed as it's sufficiently + // protected with the 0==mMask check, but I have this anyways because g++ 7 otherwise + // reports a compile error: attempt to free a non-heap object 'fm' + // [-Werror=free-nonheap-object] + if (mKeyVals != reinterpret_cast_no_cast_align_warning(&mMask)) { + ROBIN_HOOD_LOG("std::free") + std::free(mKeyVals); + } + } + + void init() noexcept { + mKeyVals = reinterpret_cast_no_cast_align_warning(&mMask); + mInfo = reinterpret_cast(&mMask); + mNumElements = 0; + mMask = 0; + mMaxNumElementsAllowed = 0; + mInfoInc = InitialInfoInc; + mInfoHashShift = InitialInfoHashShift; + } + + // members are sorted so no padding occurs + uint64_t mHashMultiplier = UINT64_C(0xc4ceb9fe1a85ec53); // 8 byte 8 + Node* mKeyVals = reinterpret_cast_no_cast_align_warning(&mMask); // 8 byte 16 + uint8_t* mInfo = reinterpret_cast(&mMask); // 8 byte 24 + size_t mNumElements = 0; // 8 byte 32 + size_t mMask = 0; // 8 byte 40 + size_t mMaxNumElementsAllowed = 0; // 8 byte 48 + InfoType mInfoInc = InitialInfoInc; // 4 byte 52 + InfoType mInfoHashShift = InitialInfoHashShift; // 4 byte 56 + // 16 byte 56 if NodeAllocator +}; + +} // namespace detail + +// map + +template , + typename KeyEqual = std::equal_to, size_t MaxLoadFactor100 = 80> +using unordered_flat_map = detail::Table; + +template , + typename KeyEqual = std::equal_to, size_t MaxLoadFactor100 = 80> +using unordered_node_map = detail::Table; + +template , + typename KeyEqual = std::equal_to, size_t MaxLoadFactor100 = 80> +using unordered_map = + detail::Table) <= sizeof(size_t) * 6 && + std::is_nothrow_move_constructible>::value && + std::is_nothrow_move_assignable>::value, + MaxLoadFactor100, Key, T, Hash, KeyEqual>; + +// set + +template , typename KeyEqual = std::equal_to, + size_t MaxLoadFactor100 = 80> +using unordered_flat_set = detail::Table; + +template , typename KeyEqual = std::equal_to, + size_t MaxLoadFactor100 = 80> +using unordered_node_set = detail::Table; + +template , typename KeyEqual = std::equal_to, + size_t MaxLoadFactor100 = 80> +using unordered_set = detail::Table::value && + std::is_nothrow_move_assignable::value, + MaxLoadFactor100, Key, void, Hash, KeyEqual>; + +} // namespace robin_hood + +#endif diff --git a/src/Flf/Archive.cc b/src/Flf/Archive.cc index 506f26a44..f7160cd04 100644 --- a/src/Flf/Archive.cc +++ b/src/Flf/Archive.cc @@ -217,7 +217,8 @@ class LatticeArchiveReaderNode : public Node { public: LatticeArchiveReaderNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), reader_(0) {} + : Precursor(name, config), + reader_(0) {} virtual ~LatticeArchiveReaderNode() {} virtual void init(const std::vector& arguments) { @@ -285,7 +286,8 @@ class LatticeArchiveWriterNode : public Node { public: LatticeArchiveWriterNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), writer_(0) {} + : Precursor(name, config), + writer_(0) {} virtual ~LatticeArchiveWriterNode() {} virtual void init(const std::vector& 
arguments) { @@ -407,7 +409,8 @@ ConfusionNetworkArchive::ConfusionNetworkArchive( const Core::Configuration& config, const std::string& pathname, CnFormat format) - : Precursor(config, pathname), format(format) { + : Precursor(config, pathname), + format(format) { encoding = paramEncoding(config); } @@ -586,7 +589,8 @@ class ConfusionNetworkArchiveReaderNode : public Node { public: ConfusionNetworkArchiveReaderNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), archiveReader_(0) { + : Precursor(name, config), + archiveReader_(0) { isValid_ = false; } virtual ~ConfusionNetworkArchiveReaderNode() {} @@ -644,7 +648,8 @@ class ConfusionNetworkArchiveWriterNode : public Node { public: ConfusionNetworkArchiveWriterNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), archiveWriter_(0) {} + : Precursor(name, config), + archiveWriter_(0) {} virtual ~ConfusionNetworkArchiveWriterNode() {} virtual void init(const std::vector& arguments) { if (!connected(0)) @@ -706,7 +711,8 @@ PosteriorCnArchive::PosteriorCnArchive( const Core::Configuration& config, const std::string& pathname, PosteriorCnFormat format) - : Precursor(config, pathname), format(format) { + : Precursor(config, pathname), + format(format) { encoding = paramEncoding(config); } @@ -884,7 +890,8 @@ class PosteriorCnArchiveReaderNode : public Node { public: PosteriorCnArchiveReaderNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), archiveReader_(0) { + : Precursor(name, config), + archiveReader_(0) { isValid_ = false; } virtual ~PosteriorCnArchiveReaderNode() {} @@ -946,7 +953,9 @@ class PosteriorCnArchiveWriterNode : public Node { public: PosteriorCnArchiveWriterNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), archiveWriter_(0), flowCache_(0) {} + : Precursor(name, config), + archiveWriter_(0), + flowCache_(0) {} virtual ~PosteriorCnArchiveWriterNode() { delete archiveWriter_; delete flowCache_; diff --git a/src/Flf/Best.cc b/src/Flf/Best.cc index 950d40eb8..a6225cdb9 100644 --- a/src/Flf/Best.cc +++ b/src/Flf/Best.cc @@ -192,9 +192,11 @@ struct StateIdPair : public std::pair { typedef std::pair Precursor; StateIdPair(Fsa::StateId first, Fsa::StateId second) : Precursor(first, second) {} + struct Hash : public std::hash { typedef std::hash Precursor; - size_t operator()(const StateIdPair& pair) const { + + size_t operator()(const StateIdPair& pair) const { return Precursor::operator()(pair.first | ~pair.second); } }; @@ -213,7 +215,8 @@ class AllPairsShortestDistance::Internal : public std::unordered_map& arguments) { f32 tmp = paramTimeThreshold(config); diff --git a/src/Flf/CenterFrameConfusionNetworkBuilder.cc b/src/Flf/CenterFrameConfusionNetworkBuilder.cc index 1f19125b0..0228b8bda 100644 --- a/src/Flf/CenterFrameConfusionNetworkBuilder.cc +++ b/src/Flf/CenterFrameConfusionNetworkBuilder.cc @@ -120,7 +120,8 @@ class CenterFrameCn : public Core::ReferenceCounted { public: StateIndexBuilder(ConstLatticeRef l, Core::Vector& stateIndex) - : TraverseState(l), stateIndex(stateIndex) {} + : TraverseState(l), + stateIndex(stateIndex) {} void build() { if (l->getTopologicalSort()) @@ -469,7 +470,8 @@ class CenterFrameCnBuilder : public Core::ReferenceCounted { Properties props; IdList arcIds; Frame() - : t(0.0), updated(true) {} + : t(0.0), + updated(true) {} }; typedef std::vector FrameList; @@ -718,7 +720,9 @@ class CenterFrameCnBuilder : public Core::ReferenceCounted { } 
CenterFrameCnBuilder() - : staticL_(0), staticB_(0), feCn_(0) {} + : staticL_(0), + staticB_(0), + feCn_(0) {} public: ~CenterFrameCnBuilder() {} @@ -911,7 +915,8 @@ class CenterFrameCnBuilderNode : public Node { public: CenterFrameCnBuilderNode(const std::string& name, const Core::Configuration& config) - : Node(name, config), n_(0) { + : Node(name, config), + n_(0) { confidenceId_ = Semiring::InvalidId; } virtual ~CenterFrameCnBuilderNode() {} diff --git a/src/Flf/Combination.cc b/src/Flf/Combination.cc index 774801845..379a2a883 100644 --- a/src/Flf/Combination.cc +++ b/src/Flf/Combination.cc @@ -31,7 +31,8 @@ class SemiringCombinationHelper::Internal { public: Internal(const KeyList& comboKeys, const ScoreList& comboScales) - : comboKeys(comboKeys), comboScales(comboScales) {} + : comboKeys(comboKeys), + comboScales(comboScales) {} virtual ~Internal() {} virtual bool update(const ConstSemiringRefList& semirings) = 0; diff --git a/src/Flf/Compose.cc b/src/Flf/Compose.cc index 0480d73d5..18c77071f 100644 --- a/src/Flf/Compose.cc +++ b/src/Flf/Compose.cc @@ -41,7 +41,9 @@ class UnweightLattice : public SlaveLattice { public: UnweightLattice(ConstLatticeRef l, ConstSemiringRef semiring) - : Precursor(l), semiring_(semiring), one_(semiring->one()) { + : Precursor(l), + semiring_(semiring), + one_(semiring->one()) { setBoundaries(InvalidBoundaries); } virtual ~UnweightLattice() {} diff --git a/src/Flf/Concatenate.cc b/src/Flf/Concatenate.cc index 6b88a2ec5..0c2f0b4c3 100644 --- a/src/Flf/Concatenate.cc +++ b/src/Flf/Concatenate.cc @@ -45,7 +45,10 @@ struct BoundedSegment { Time offset; s32 startFrame, endFrame; BoundedSegment() - : segment(), offset(0), startFrame(Core::Type::max), endFrame(Core::Type::max) {} + : segment(), + offset(0), + startFrame(Core::Type::max), + endFrame(Core::Type::max) {} }; typedef std::vector BoundedSegmentList; s32 startFrame, endFrame; @@ -227,7 +230,8 @@ class ConcatenateNode : public Node { public: ConcatenateNode(const std::string& name, const Core::Configuration& config) - : Node(name, config), dump_(config, "dump") {} + : Node(name, config), + dump_(config, "dump") {} virtual ~ConcatenateNode() {} virtual ConstSegmentRef sendSegment(Port to) { @@ -346,7 +350,11 @@ class ConcatenateLatticesNode : public ConcatenateNode { public: ConcatenatedLatticeBuilder() - : s_(0), b_(0), maxSid_(0), finalSid_(0), endTime_(0) {} + : s_(0), + b_(0), + maxSid_(0), + finalSid_(0), + endTime_(0) {} ~ConcatenatedLatticeBuilder() { delete s_; } @@ -383,7 +391,8 @@ class ConcatenateLatticesNode : public ConcatenateNode { public: ConcatenateLatticesNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), reader_(0) {} + : Precursor(name, config), + reader_(0) {} virtual ~ConcatenateLatticesNode() {} virtual void init(const std::vector& arguments) { @@ -466,7 +475,8 @@ class ConcatenateFCnsNode : public ConcatenateNode { public: ConcatenateFCnsNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), reader_(0) {} + : Precursor(name, config), + reader_(0) {} virtual ~ConcatenateFCnsNode() {} virtual void init(const std::vector& arguments) { diff --git a/src/Flf/ConfusionNetworkCombination.cc b/src/Flf/ConfusionNetworkCombination.cc index bb717cd69..e81eb858f 100644 --- a/src/Flf/ConfusionNetworkCombination.cc +++ b/src/Flf/ConfusionNetworkCombination.cc @@ -51,7 +51,8 @@ class ConfusionNetworkAlignment { public: Cost() - : refPtr_(0), hypPtr_(0) {} + : refPtr_(0), + hypPtr_(0) {} virtual ~Cost() {} // called 
before the alignment of ref and hyp starts, i.e. before the first call to the cost functions virtual void init(const ConfusionNetwork& ref, const ConfusionNetwork& hyp) const { @@ -499,7 +500,12 @@ class ConfusionNetworkCombination : public ConfusionNetworkAlignment, public Cor public: WeightedCost() - : refWeight(0.5), hypWeight(0.5), refNorm(1.0), hypNorm(1.0), refNormedWeight(0.5), hypNormedWeight(0.5) {} + : refWeight(0.5), + hypWeight(0.5), + refNorm(1.0), + hypNorm(1.0), + refNormedWeight(0.5), + hypNormedWeight(0.5) {} virtual ~WeightedCost() {} void setPosteriorIds(ScoreId refPosteriorId, ScoreId hypPosteriorId) { verify((refPosteriorId != Semiring::InvalidId) && (hypPosteriorId != Semiring::InvalidId)); @@ -1102,7 +1108,9 @@ class RoverCombination : public ConfusionNetworkAlignment, public Core::Referenc public: NistScliteCost(std::string null = "@") - : Precursor(), null(null), nullLabel(Fsa::Epsilon) {} + : Precursor(), + null(null), + nullLabel(Fsa::Epsilon) {} void setAlphabet(Fsa::ConstAlphabetRef alphabet) { this->alphabet = alphabet; @@ -1229,14 +1237,12 @@ class RoverCombination : public ConfusionNetworkAlignment, public Core::Referenc virtual Score delCost(u32 refId) const { const ConfusionNetwork::Slot& ref = refSlot(refId); verify(!ref.empty()); - Score cost = Infinity, tmpCost = 0.0; - Fsa::LabelId label = Fsa::InvalidLabelId; + Score cost = Infinity, tmpCost = 0.0; for (ConfusionNetwork::Slot::const_iterator itRef = ref.begin(), endRef = ref.end(); itRef != endRef; ++itRef) { const ConfusionNetwork::Arc& refArc = *itRef; tmpCost = (refArc.label == Fsa::Epsilon) ? 0.002 : Score(refArc.duration) / 100.0 + 0.003; if (tmpCost < cost) { - label = refArc.label; - cost = tmpCost; + cost = tmpCost; } } return cost; @@ -1251,8 +1257,7 @@ class RoverCombination : public ConfusionNetworkAlignment, public Core::Referenc cost = delCost(refId) - 0.001; } else { - Score tmpCost = 0.0; - Fsa::LabelId label = Fsa::InvalidLabelId; + Score tmpCost = 0.0; for (ConfusionNetwork::Slot::const_iterator itRef = ref.begin(), endRef = ref.end(); itRef != endRef; ++itRef) { const ConfusionNetwork::Arc& refArc = *itRef; if (refArc.label != Fsa::Epsilon) { @@ -1266,8 +1271,7 @@ class RoverCombination : public ConfusionNetworkAlignment, public Core::Referenc if (refArc.label != hypArc.label) tmpCost += 0.001; if (tmpCost < cost) { - label = refArc.label; - cost = tmpCost; + cost = tmpCost; } } } @@ -1844,7 +1848,8 @@ class OracleAlignment : public ConfusionNetworkAlignment, public Core::Reference public: WeightedOracleCost(f64 alpha) - : PosteriorCost(), alpha_(alpha) {} + : PosteriorCost(), + alpha_(alpha) {} virtual std::string describe() const { return Core::form("weighted-oracle-cost(alpha=%.2f)", alpha_); diff --git a/src/Flf/ConfusionNetworkIo.cc b/src/Flf/ConfusionNetworkIo.cc index 282e3dd1d..2461acbe1 100644 --- a/src/Flf/ConfusionNetworkIo.cc +++ b/src/Flf/ConfusionNetworkIo.cc @@ -48,7 +48,9 @@ class CnSortingIterator { private: CnSortingIterator(ConstConfusionNetworkRef cn) - : cn_(cn), posteriorId_(Semiring::InvalidId), slot_(0) { + : cn_(cn), + posteriorId_(Semiring::InvalidId), + slot_(0) { verify(cn); if (cn->isNormalized()) { posteriorId_ = cn->normalizedProperties->posteriorId; @@ -428,7 +430,8 @@ class ConfusionNetworkXmlParser : public Core::XmlSchemaParser { public: ConfusionNetworkXmlParser(const Core::Configuration& config) - : Precursor(config), cn_(0) { + : Precursor(config), + cn_(0) { oracleAlignment_ = 0; semiringXmlElement_ = Semiring::xmlElement(this); 
oracleSemiringXmlElement_ = Semiring::xmlElement(this); diff --git a/src/Flf/Convert.cc b/src/Flf/Convert.cc index 992555723..a1c4c5505 100644 --- a/src/Flf/Convert.cc +++ b/src/Flf/Convert.cc @@ -32,14 +32,15 @@ class SymbolicMapping : public std::vector { : Precursor(capacity) {} static const Core::ParameterStringVector paramRow; - static SymbolicMapping loadSymbolicMapping(const Core::Configuration& config, ConstSemiringRef target) { + + static SymbolicMapping loadSymbolicMapping(const Core::Configuration& config, ConstSemiringRef target) { SymbolicMapping symbolicMapping(target->size()); SymbolicMapping::iterator itSymbolicMapping = symbolicMapping.begin(); for (KeyList::const_iterator itKey = target->keys().begin(); itKey != target->keys().end(); ++itKey, ++itSymbolicMapping) { if (itKey->empty()) Core::Application::us()->warning("Dimension with no symbolic identifier in \"%s\"", - target->name().c_str()); + target->name().c_str()); else *itSymbolicMapping = paramRow(Core::Configuration(config, *itKey)); } @@ -57,7 +58,8 @@ class MappingBuilder { public: MappingBuilder(ConstSemiringRef source, ConstSemiringRef target) - : source_(source), target_(target) {} + : source_(source), + target_(target) {} ProjectionMatrix buildLinearMapping(ScoreId offset = 0, bool scaled = true) const { ProjectionMatrix scaledMapping(target_->size()); @@ -114,7 +116,9 @@ class ProjectSemiringLattice : public ModifyLattice { public: ProjectSemiringLattice(ConstLatticeRef l, ConstSemiringRef targetSemiring, const ProjectionMatrix& mapping) - : Precursor(l), semiring_(targetSemiring), mapping_(mapping) {} + : Precursor(l), + semiring_(targetSemiring), + mapping_(mapping) {} virtual ~ProjectSemiringLattice() {} virtual ConstSemiringRef semiring() const { @@ -240,7 +244,9 @@ class FsaToLattice : public Lattice { public: WeightToScoresMap(ConstSemiringRef semiring, ScoreId id, ScoresRef scores) - : semiring_(semiring), id_(id), scores_(scores) { + : semiring_(semiring), + id_(id), + scores_(scores) { verify(id < semiring->size()); } virtual ScoresRef operator()(Fsa::Weight w) const { @@ -274,7 +280,12 @@ class FsaToLattice : public Lattice { FsaToLattice(Fsa::ConstAutomatonRef fsa, ConstSemiringRef semiring, WeightMapRef weightMap, Lexicon::AlphabetMapRef inputMap, Lexicon::AlphabetMapRef outputMap) - : fsa_(fsa), isAcceptor_(fsa->type() == Fsa::TypeAcceptor), semiring_(semiring), weightMap_(weightMap), inputMap_(inputMap), outputMap_(outputMap) { + : fsa_(fsa), + isAcceptor_(fsa->type() == Fsa::TypeAcceptor), + semiring_(semiring), + weightMap_(weightMap), + inputMap_(inputMap), + outputMap_(outputMap) { setProperties(fsa->knownProperties(), fsa->properties()); } diff --git a/src/Flf/Copy.cc b/src/Flf/Copy.cc index d8083fa51..31402d281 100644 --- a/src/Flf/Copy.cc +++ b/src/Flf/Copy.cc @@ -264,7 +264,9 @@ class BoundariesCopyBuilder : public TraverseState { public: BoundariesCopyBuilder(ConstLatticeRef l, StaticBoundaries* staticBoundaries) - : Precursor(l), boundaries_(l->getBoundaries()), staticBoundaries_(staticBoundaries) { + : Precursor(l), + boundaries_(l->getBoundaries()), + staticBoundaries_(staticBoundaries) { staticBoundaries->clear(); traverse(); } diff --git a/src/Flf/CorpusProcessor.cc b/src/Flf/CorpusProcessor.cc index 5aff35a30..69488dd76 100644 --- a/src/Flf/CorpusProcessor.cc +++ b/src/Flf/CorpusProcessor.cc @@ -89,7 +89,8 @@ void CorpusProcessor::processSpeechSegment(Bliss::SpeechSegment* segment) { // ------------------------------------------------------------------------- 
SpeechSegmentNode::SpeechSegmentNode(const std::string& name, const Core::Configuration& config) - : Node(name, config), blissSpeechSegment_(0) {} + : Node(name, config), + blissSpeechSegment_(0) {} void SpeechSegmentNode::init(const std::vector& arguments) { if (!in().empty()) diff --git a/src/Flf/CorpusProcessor.hh b/src/Flf/CorpusProcessor.hh index 14606a24c..c01e94d21 100644 --- a/src/Flf/CorpusProcessor.hh +++ b/src/Flf/CorpusProcessor.hh @@ -60,15 +60,19 @@ public: SpeechSegmentNode(const std::string& name, const Core::Configuration& config); virtual void init(const std::vector& arguments); virtual void sync(); - bool synced() const { + + bool synced() const { return !blissSpeechSegment_; } + virtual bool good() { return true; } + void setSpeechSegment(Bliss::SpeechSegment* blissSpeechSegment) { blissSpeechSegment_ = blissSpeechSegment; } + virtual ConstSegmentRef sendSegment(Port to); virtual const void* sendData(Port to); }; diff --git a/src/Flf/Determinize.cc b/src/Flf/Determinize.cc index 814cea05c..659f28af1 100644 --- a/src/Flf/Determinize.cc +++ b/src/Flf/Determinize.cc @@ -60,7 +60,8 @@ class DeterminizeNode : public FilterNode { public: DeterminizeNode(const std::string& name, const Core::Configuration& config) - : FilterNode(name, config), toLogSemiring_(false) {} + : FilterNode(name, config), + toLogSemiring_(false) {} ~DeterminizeNode() {} virtual void init(const std::vector& arguments) { toLogSemiring_ = paramToLogSemiring(config); diff --git a/src/Flf/EpsilonRemoval.cc b/src/Flf/EpsilonRemoval.cc index 4eaa8cf8b..07d07020d 100644 --- a/src/Flf/EpsilonRemoval.cc +++ b/src/Flf/EpsilonRemoval.cc @@ -57,7 +57,8 @@ class ArcSortLattice : public ModifyLattice { public: ArcSortLattice(ConstLatticeRef l) - : ModifyLattice(l), weakOrder_(l) { + : ModifyLattice(l), + weakOrder_(l) { this->setProperties(Fsa::PropertySorted, WeakOrder::properties()); } virtual ~ArcSortLattice() {} @@ -199,7 +200,8 @@ class ArcRemovalNode : public FilterNode { public: ArcRemovalNode(const std::string& name, const Core::Configuration& config) - : FilterNode(name, config), toLogSemiring_(false) {} + : FilterNode(name, config), + toLogSemiring_(false) {} ~ArcRemovalNode() {} virtual void init(const std::vector& arguments) { toLogSemiring_ = paramToLogSemiring(config); @@ -295,7 +297,9 @@ class NullArcFilter { const Boundaries& boundaries; const Time t; Filter(ConstLatticeRef l, ConstStateRef sr) - : l(l), boundaries(*l->getBoundaries()), t(l->getBoundaries()->get(sr->id()).time()) {} + : l(l), + boundaries(*l->getBoundaries()), + t(l->getBoundaries()->get(sr->id()).time()) {} bool operator()(const Arc& a) const { if (boundaries.get(a.target()).time() == t) { if ((a.input() != Fsa::Epsilon) || (a.output() != Fsa::Epsilon)) @@ -312,7 +316,8 @@ class NullArcFilter { const Boundaries& boundaries; Time t; WeakOrder(ConstLatticeRef l) - : l(l), boundaries(*l->getBoundaries()) {} + : l(l), + boundaries(*l->getBoundaries()) {} bool operator()(const Arc& a1, const Arc& a2) const { Time t1 = boundaries.get(a1.target()).time(), t2 = boundaries.get(a1.target()).time(); if (t1 < t2) diff --git a/src/Flf/Evaluate.cc b/src/Flf/Evaluate.cc index 7ac0e1061..27a111e3c 100644 --- a/src/Flf/Evaluate.cc +++ b/src/Flf/Evaluate.cc @@ -242,7 +242,8 @@ class EvaluatorNode : public FilterNode { public: EvaluatorNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), evaluator_(0) { + : Precursor(name, config), + evaluator_(0) { validReference_ = false; } virtual ~EvaluatorNode() { diff 
--git a/src/Flf/Filter.cc b/src/Flf/Filter.cc index a31573eb5..caf73bb69 100644 --- a/src/Flf/Filter.cc +++ b/src/Flf/Filter.cc @@ -30,7 +30,8 @@ class ArcFilterLattice : public SlaveLattice { public: ArcFilterLattice(ConstLatticeRef l, const ArcFilter& filter) - : Precursor(l), filter_(filter) {} + : Precursor(l), + filter_(filter) {} virtual ~ArcFilterLattice() {} virtual ConstStateRef getState(Fsa::StateId sid) const { @@ -56,7 +57,8 @@ struct InputFilter { Fsa::ConstAlphabetRef alphabet; Fsa::LabelId input; InputFilter(Fsa::ConstAlphabetRef alphabet, Fsa::LabelId input) - : alphabet(alphabet), input(input) {} + : alphabet(alphabet), + input(input) {} bool operator()(const Arc& a) const { return a.input() == input; } @@ -82,7 +84,8 @@ struct OutputFilter { Fsa::ConstAlphabetRef alphabet; Fsa::LabelId output; OutputFilter(Fsa::ConstAlphabetRef alphabet, Fsa::LabelId output) - : alphabet(alphabet), output(output) {} + : alphabet(alphabet), + output(output) {} bool operator()(const Arc& a) const { return a.output() == output; } diff --git a/src/Flf/FlfCore/Basic.cc b/src/Flf/FlfCore/Basic.cc index c3cf0942d..33ea78929 100644 --- a/src/Flf/FlfCore/Basic.cc +++ b/src/Flf/FlfCore/Basic.cc @@ -58,7 +58,9 @@ class ProjectSemiringLattice : public ModifyLattice { public: ProjectSemiringLattice(ConstLatticeRef l, ConstSemiringRef targetSemiring, const ProjectionMatrix& mapping) - : Precursor(l), semiring_(targetSemiring), mapping_(mapping) { + : Precursor(l), + semiring_(targetSemiring), + mapping_(mapping) { verify(mapping_.size() <= targetSemiring->size()); ConstSemiringRef sourceSemiring = l->semiring(); for (ProjectionMatrix::const_iterator itScales = mapping_.begin(); itScales != mapping_.end(); ++itScales) @@ -95,7 +97,9 @@ struct FsaWeightToScores { ScoresRef defaultScore; ScoreId id; FsaWeightToScores(ConstSemiringRef semiring, ScoresRef defaultScore, ScoreId id) - : semiring(semiring), defaultScore(defaultScore), id(id) {} + : semiring(semiring), + defaultScore(defaultScore), + id(id) {} ScoresRef operator()(const Fsa::Weight& w) const { ScoresRef score = semiring->clone(defaultScore); score->set(id, Score(w)); @@ -125,7 +129,11 @@ class FsaVectorLattice : public Lattice { public: FsaVectorLattice(const std::vector& fsas, ConstSemiringRef semiring) - : Precursor(), fsas_(fsas), semiring_(semiring), tmpStates_(0), tmpArcIterators_(0) { + : Precursor(), + fsas_(fsas), + semiring_(semiring), + tmpStates_(0), + tmpArcIterators_(0) { verify((fsas_.size() > 0) && (fsas_.size() == semiring_->size())); tmpStates_.resize(fsas_.size()); tmpArcIterators_.resize(fsas_.size()); @@ -199,7 +207,8 @@ struct FsaWeightToConstant { ConstSemiringRef semiring; ScoresRef constScore; FsaWeightToConstant(ConstSemiringRef semiring, ScoresRef constScore) - : semiring(semiring), constScore(constScore) {} + : semiring(semiring), + constScore(constScore) {} ScoresRef operator()(const Fsa::Weight& w) const { return constScore; } @@ -242,7 +251,8 @@ struct Projection { ScoreId id; Score scale; Projection(ScoreId id, Score scale = Score(1)) - : id(id), scale(scale) {} + : id(id), + scale(scale) {} Fsa::Weight operator()(const ScoresRef& a) const { return Fsa::Weight(scale * a->get(id)); } @@ -295,7 +305,9 @@ class TopologicalOrderBuilder : protected DfsState { public: TopologicalOrderBuilder(ConstLatticeRef l, Core::Ref map) - : DfsState(l), map_(map), isCyclic_(false) {} + : DfsState(l), + map_(map), + isCyclic_(false) {} bool isCyclic() const { return isCyclic_; } @@ -351,7 +363,8 @@ class FindTopologicalOrder : 
public TopologicalOrderBuilder { public: FindTopologicalOrder(ConstLatticeRef l, Core::Ref map) - : TopologicalOrderBuilder(l, map), time_(0) { + : TopologicalOrderBuilder(l, map), + time_(0) { dfs(); if (time_ > 0) { if (!isCyclic()) { @@ -397,15 +410,19 @@ ConstStateMapRef findTopologicalOrder(ConstLatticeRef l) { struct ChronologicalWeakOrder { const Boundaries& boundaries; const StateMap& topologicalOrder; - bool operator()(Fsa::StateId sid1, Fsa::StateId sid2) const { + + bool operator()(Fsa::StateId sid1, Fsa::StateId sid2) const { if (boundaries.time(sid1) == boundaries.time(sid2)) return topologicalOrder[sid1] < topologicalOrder[sid2]; else return boundaries.time(sid1) < boundaries.time(sid2); } + ChronologicalWeakOrder(const Boundaries& boundaries, const StateMap& topologicalOrder) - : boundaries(boundaries), topologicalOrder(topologicalOrder) {} + : boundaries(boundaries), + topologicalOrder(topologicalOrder) {} }; + ConstStateMapRef sortChronologically(ConstLatticeRef l) { ConstStateMapRef topologicalSort = sortTopologically(l); verify(topologicalSort); @@ -429,7 +446,8 @@ class TopologicalOrderLattice : public SlaveLattice { ConstStateMapRef topologicalSort_; // order = new-sid -> old-sid public: TopologicalOrderBoundaries(ConstBoundariesRef boundaries, ConstStateMapRef topologicalSort) - : boundaries_(boundaries), topologicalSort_(topologicalSort) {} + : boundaries_(boundaries), + topologicalSort_(topologicalSort) {} virtual ~TopologicalOrderBoundaries() {} bool valid() const { return boundaries_->valid(); @@ -536,7 +554,9 @@ class ScoreAppendLattice : public SlaveLattice { public: ScoreAppendLattice(ConstLatticeRef l1, ConstLatticeRef l2, ConstSemiringRef semiring) - : Precursor(l1), appendL_(l2), semiring_(semiring) { + : Precursor(l1), + appendL_(l2), + semiring_(semiring) { const Semiring& semiring1 = *l1->semiring(); n1_ = semiring1.size(); const Semiring& semiring2 = *l2->semiring(); diff --git a/src/Flf/FlfCore/Boundaries.hh b/src/Flf/FlfCore/Boundaries.hh index 9c88abd73..e595d72b9 100644 --- a/src/Flf/FlfCore/Boundaries.hh +++ b/src/Flf/FlfCore/Boundaries.hh @@ -70,9 +70,11 @@ public: Boundary() : time_(Speech::InvalidTimeframeIndex) {} Boundary(Speech::TimeframeIndex time) - : time_(time), transit_() {} + : time_(time), + transit_() {} Boundary(Speech::TimeframeIndex time, const Transit& transit) - : time_(time), transit_(transit) {} + : time_(time), + transit_(transit) {} void setTime(Speech::TimeframeIndex time) { time_ = time; @@ -164,7 +166,8 @@ private: public: MappedBoundaries(ConstBoundariesRef boundaries, Fsa::ConstMappingRef map) - : boundaries_(boundaries), map_(map) {} + : boundaries_(boundaries), + map_(map) {} virtual ~MappedBoundaries() {} virtual bool valid() const { diff --git a/src/Flf/FlfCore/Lattice.cc b/src/Flf/FlfCore/Lattice.cc index cb2785d12..9f21b9d2c 100644 --- a/src/Flf/FlfCore/Lattice.cc +++ b/src/Flf/FlfCore/Lattice.cc @@ -21,7 +21,8 @@ namespace Flf { Lattice::nLattices = 0; Lattice::Lattice() - : Precursor(), boundaries_(InvalidBoundaries) { + : Precursor(), + boundaries_(InvalidBoundaries) { ++nLattices; } @@ -30,7 +31,8 @@ Lattice::~Lattice() { } #else Lattice::Lattice() - : Precursor(), boundaries_(InvalidBoundaries) {} + : Precursor(), + boundaries_(InvalidBoundaries) {} Lattice::~Lattice() {} #endif diff --git a/src/Flf/FlfCore/Lattice.hh b/src/Flf/FlfCore/Lattice.hh index 220121650..bc73cad5a 100644 --- a/src/Flf/FlfCore/Lattice.hh +++ b/src/Flf/FlfCore/Lattice.hh @@ -40,9 +40,15 @@ struct Arc { Fsa::LabelId output_; Arc() {} 
Arc(Fsa::StateId target, ScoresRef a, Fsa::LabelId input) - : target_(target), weight_(a), input_(input), output_(input) {} + : target_(target), + weight_(a), + input_(input), + output_(input) {} Arc(Fsa::StateId target, ScoresRef a, Fsa::LabelId input, Fsa::LabelId output) - : target_(target), weight_(a), input_(input), output_(output) {} + : target_(target), + weight_(a), + input_(input), + output_(output) {} ~Arc() {} Fsa::StateId target() const { return target_; @@ -102,11 +108,14 @@ public: public: StateMap() - : Precursor(), maxSid(Fsa::InvalidStateId) {} + : Precursor(), + maxSid(Fsa::InvalidStateId) {} StateMap(size_t n) - : Precursor(n), maxSid(Fsa::InvalidStateId) {} + : Precursor(n), + maxSid(Fsa::InvalidStateId) {} StateMap(size_t n, Fsa::StateId def) - : Precursor(n, def), maxSid(Fsa::InvalidStateId) {} + : Precursor(n, def), + maxSid(Fsa::InvalidStateId) {} }; typedef Core::Ref StateMapRef; typedef Core::Ref ConstStateMapRef; @@ -179,11 +188,26 @@ struct Arc { from = to = Core::Type::max; } Arc(Label label, ScoresRef scores) - : label(label), scores(scores), begin(Speech::InvalidTimeframeIndex), duration(0), from(Core::Type::max), to(Core::Type::max) {} + : label(label), + scores(scores), + begin(Speech::InvalidTimeframeIndex), + duration(0), + from(Core::Type::max), + to(Core::Type::max) {} Arc(Label label, ScoresRef scores, Time begin, Time duration) - : label(label), scores(scores), begin(begin), duration(duration), from(Core::Type::max), to(Core::Type::max) {} + : label(label), + scores(scores), + begin(begin), + duration(duration), + from(Core::Type::max), + to(Core::Type::max) {} Arc(Label label, ScoresRef scores, Time begin, Time duration, u32 from, u32 to) - : label(label), scores(scores), begin(begin), duration(duration), from(from), to(to) {} + : label(label), + scores(scores), + begin(begin), + duration(duration), + from(from), + to(to) {} inline bool operator<(const Arc& a) const { return label < a.label; } @@ -194,7 +218,8 @@ struct PosteriorArc { Probability score; PosteriorArc() {} PosteriorArc(Label label, Probability score) - : label(label), score(score) {} + : label(label), + score(score) {} inline bool operator<(const PosteriorArc& a) const { return label < a.label; } @@ -281,7 +306,8 @@ public: Fsa::StateId sid; Fsa::StateId aid; Mapping(Fsa::StateId sid, Fsa::StateId aid) - : sid(sid), aid(aid) {} + : sid(sid), + aid(aid) {} }; static const Mapping InvalidMapping; typedef std::vector Map; diff --git a/src/Flf/FlfCore/LatticeInternal.cc b/src/Flf/FlfCore/LatticeInternal.cc index 73bd44a08..2a3d7dca8 100644 --- a/src/Flf/FlfCore/LatticeInternal.cc +++ b/src/Flf/FlfCore/LatticeInternal.cc @@ -44,7 +44,9 @@ class RescoreLattice::StateRescorer { public: StateRescorer(const RescoreLattice* l) - : l(*l), predecessor(*l->fsa_), semiring(*l->semiring()) {} + : l(*l), + predecessor(*l->fsa_), + semiring(*l->semiring()) {} virtual ~StateRescorer() {} virtual ConstStateRef operator()(Fsa::StateId sid) const = 0; }; diff --git a/src/Flf/FlfCore/Semiring.hh b/src/Flf/FlfCore/Semiring.hh index 89ba7c621..46e5ad158 100644 --- a/src/Flf/FlfCore/Semiring.hh +++ b/src/Flf/FlfCore/Semiring.hh @@ -116,13 +116,17 @@ public: * scales **/ virtual void setScales(const ScoreList& scales) const; - ScoreList& scales() const { + + ScoreList& scales() const { return scales_; } + virtual void setScale(ScoreId id, Score scale) const; - Score scale(ScoreId id) const { + + Score scale(ScoreId id) const { return scales_[id]; } + /* * return value: * 0 scales are equal diff --git 
a/src/Flf/FlfCore/TopologicalOrderQueue.hh b/src/Flf/FlfCore/TopologicalOrderQueue.hh index 86ecfbb24..7bdc959b9 100644 --- a/src/Flf/FlfCore/TopologicalOrderQueue.hh +++ b/src/Flf/FlfCore/TopologicalOrderQueue.hh @@ -43,7 +43,8 @@ private: public: TopologicalOrderQueue(const WeakTopologicalOrder& weakOrder) - : Precursor(weakOrder), topologicalOrder_(weakOrder.topologicalOrder) {} + : Precursor(weakOrder), + topologicalOrder_(weakOrder.topologicalOrder) {} ConstStateMapRef getTopologicalOrder() const { return topologicalOrder_; } @@ -74,7 +75,8 @@ private: public: ReverseTopologicalOrderQueue(const WeakReverseTopologicalOrder& weakOrder) - : Precursor(weakOrder), topologicalOrder_(weakOrder.topologicalOrder) {} + : Precursor(weakOrder), + topologicalOrder_(weakOrder.topologicalOrder) {} ConstStateMapRef getTopologicalOrder() const { return topologicalOrder_; } diff --git a/src/Flf/FlfCore/Utility.cc b/src/Flf/FlfCore/Utility.cc index 97224021e..8df76effa 100644 --- a/src/Flf/FlfCore/Utility.cc +++ b/src/Flf/FlfCore/Utility.cc @@ -117,8 +117,7 @@ const TextFileParser::StringList& TextFileParser::next() { n_++; std::string line; std::getline(tis_, line); - bool isEscaped = false; - const char* comment = 0; + bool isEscaped = false; for (const char* c = line.c_str();;) { for (; (*c != '\0') && ::isspace(*c); ++c) ; @@ -166,7 +165,6 @@ const TextFileParser::StringList& TextFileParser::next() { isEscaped = true; } else if ((*c == '#') || ((*c == ';') && (*(c + 1) == ';'))) { - comment = c; if (columns_.back().empty()) columns_.pop_back(); isBreak = true; diff --git a/src/Flf/FlfExt/AcousticAlignment.cc b/src/Flf/FlfExt/AcousticAlignment.cc index 50553b0b8..8c970d330 100644 --- a/src/Flf/FlfExt/AcousticAlignment.cc +++ b/src/Flf/FlfExt/AcousticAlignment.cc @@ -50,7 +50,10 @@ struct AlignedCoarticulatedLemmaPronunciation { Key( const CoarticulatedLemmaPronunciation& coLp, Time beginTime, Time endTime) - : coLp(coLp), id(coLp.object().id()), beginTime(beginTime), endTime(endTime) {} + : coLp(coLp), + id(coLp.object().id()), + beginTime(beginTime), + endTime(endTime) {} }; }; struct AlignedCoarticulatedLemmaPronunciation::Key::Hash { @@ -168,7 +171,11 @@ Time SubWordAlignment::duration() const { // ------------------------------------------------------------------------- LatticeAlignment::LatticeAlignment(ConstLatticeRef l, AcousticPhonemeSequenceAligner* aligner, const LabelMapList& subwordMaps) - : l_(l), isLemma_(false), aligner_(aligner), subwordMaps_(subwordMaps), size_(aligner->features()->size()) { + : l_(l), + isLemma_(false), + aligner_(aligner), + subwordMaps_(subwordMaps), + size_(aligner->features()->size()) { switch (Lexicon::us()->alphabetId(l_->getInputAlphabet())) { case Lexicon::LemmaAlphabetId: isLemma_ = true; @@ -432,7 +439,8 @@ struct SubwordArcAligner { const LatticeAlignment& latticeAlignment; u32 i_; SubwordArcAligner(const LatticeAlignment& latticeAlignment, u32 i) - : latticeAlignment(latticeAlignment), i_(i) {} + : latticeAlignment(latticeAlignment), + i_(i) {} Fsa::ConstAlphabetRef alphabet() const { return Lexicon::us()->lemmaPronunciationAlphabet(); } @@ -519,7 +527,10 @@ class LatticeFromLatticeAlignmentBuilder : public TraverseState { } LatticeFromLatticeAlignmentBuilder(ConstLatticeRef l, const ArcAligner& arcAligner, StaticLattice* s, StaticBoundaries* b) - : TraverseState(l), arcAligner_(arcAligner), s_(s), b_(b) { + : TraverseState(l), + arcAligner_(arcAligner), + s_(s), + b_(b) { semiring_ = l->semiring(); boundaries_ = l->getBoundaries(); Fsa::StateId 
initialSid = l->initialStateId(); @@ -726,7 +737,9 @@ ConstPosteriorCnRef LatticeAlignment::subwordFramePosteriorCn(ConstFwdBwdRef fb, // ------------------------------------------------------------------------- LatticeAlignmentBuilder::LatticeAlignmentBuilder(const Core::Configuration& config, AcousticPhonemeSequenceAligner* aligner, const LabelMapList& subwordMaps) - : Core::Component(config), aligner_(aligner), subwordMaps_(subwordMaps) { + : Core::Component(config), + aligner_(aligner), + subwordMaps_(subwordMaps) { if (aligner_->acousticModel()->hmmTopologySet()->getDefault().nPhoneStates() < 3) Core::Application::us()->warning("Default HMM has less than 3 states; proper function of phoneme/sub-word-unit alignment cannot be guaranteed."); for (LabelMapList::const_iterator itMap = subwordMaps.begin(); itMap != subwordMaps.end(); ++itMap) { @@ -851,7 +864,12 @@ class ExtendByAcousticScoreLattice : public RescoreLattice { Score maxScore, bool scoreEpsArcs, RescoreMode rescoreMode) - : Precursor(l, rescoreMode), latticeAlignment_(latticeAlignment), id_(id), scale_(scale), maxScore_(maxScore), scoreEpsArcs_(scoreEpsArcs) { + : Precursor(l, rescoreMode), + latticeAlignment_(latticeAlignment), + id_(id), + scale_(scale), + maxScore_(maxScore), + scoreEpsArcs_(scoreEpsArcs) { require(latticeAlignment); } virtual ~ExtendByAcousticScoreLattice() {} @@ -1063,7 +1081,10 @@ class PosteriorCnFeatureLattice : public RescoreLattice { PosteriorCnFeatureLattice( ConstLatticeRef l, RescoreMode rescoreMode, const Configuration& config, ConstLatticeAlignmentRef latticeAlignment, ConstPosteriorCnRef phonemeCn) - : Precursor(l, rescoreMode), config_(config), latticeAlignment_(latticeAlignment), phonemeCn_(phonemeCn) {} + : Precursor(l, rescoreMode), + config_(config), + latticeAlignment_(latticeAlignment), + phonemeCn_(phonemeCn) {} virtual void rescore(State* sp) const { const Boundaries& boundaries = *fsa_->getBoundaries(); @@ -1277,7 +1298,8 @@ ConstLatticeRef OrthographyAlignment::lattice() const { // ------------------------------------------------------------------------- OrthographyAlignmentBuilder::OrthographyAlignmentBuilder(const Core::Configuration& config, AcousticOrthographyAligner* aligner) - : Core::Component(config), aligner_(aligner) {} + : Core::Component(config), + aligner_(aligner) {} OrthographyAlignmentBuilder::~OrthographyAlignmentBuilder() { delete aligner_; diff --git a/src/Flf/FlfExt/AcousticAlignment.hh b/src/Flf/FlfExt/AcousticAlignment.hh index de756ee3b..ac5302f74 100644 --- a/src/Flf/FlfExt/AcousticAlignment.hh +++ b/src/Flf/FlfExt/AcousticAlignment.hh @@ -36,9 +36,15 @@ struct SubWord { Time duration; Bliss::Phoneme::Id leftContext, rightContext; SubWord(Fsa::LabelId label, Time duration) - : label(label), duration(duration), leftContext(Bliss::Phoneme::term), rightContext(Bliss::Phoneme::term) {} + : label(label), + duration(duration), + leftContext(Bliss::Phoneme::term), + rightContext(Bliss::Phoneme::term) {} SubWord(Fsa::LabelId label, Time duration, Bliss::Phoneme::Id leftContext, Bliss::Phoneme::Id rightContext) - : label(label), duration(duration), leftContext(leftContext), rightContext(rightContext) {} + : label(label), + duration(duration), + leftContext(leftContext), + rightContext(rightContext) {} }; class SubWordAlignment : public std::vector, public Core::ReferenceCounted { @@ -50,11 +56,14 @@ private: public: SubWordAlignment() - : Precursor(), label_(Fsa::InvalidLabelId) {} + : Precursor(), + label_(Fsa::InvalidLabelId) {} SubWordAlignment(Fsa::LabelId label) -
: Precursor(), label_(label) {} + : Precursor(), + label_(label) {} SubWordAlignment(Fsa::LabelId label, const SubWord& sw) - : Precursor(1, sw), label_(label) {} + : Precursor(1, sw), + label_(label) {} void setLabel(Fsa::LabelId label) { label_ = label; } @@ -88,7 +97,8 @@ public: void setNonWordLemmaPronunciations(const Lexicon::ConstLemmaPronunciationPtrList& nonWordLemmaProns); const AcousticPhonemeSequenceAligner* aligner() const; - u32 size() const { + + u32 size() const { return size_; } diff --git a/src/Flf/FlfExt/MtConfusionNetwork.cc b/src/Flf/FlfExt/MtConfusionNetwork.cc index 94291aa52..ad1754fa7 100644 --- a/src/Flf/FlfExt/MtConfusionNetwork.cc +++ b/src/Flf/FlfExt/MtConfusionNetwork.cc @@ -53,7 +53,9 @@ class MtCnFeatureLattice : public RescoreLattice { public: MtCnFeatureLattice(ConstLatticeRef l, ConstConfusionNetworkRef cn, RescoreMode rescoreMode, const FeatureIds& ids) - : Precursor(l, rescoreMode), cn_(cn), ids_(ids) { + : Precursor(l, rescoreMode), + cn_(cn), + ids_(ids) { verify(cn_->hasMap()); if ((ids_.confidenceId != Semiring::InvalidId) || (ids_.scoreId != Semiring::InvalidId) || (ids_.slotEntropyId != Semiring::InvalidId) || (ids_.nonEpsSlotId != Semiring::InvalidId)) normalizedCn_ = (cn_->isNormalized()) ? cn_ : normalizeCn(cn_, ids_.cnPosteriorId); @@ -190,7 +192,9 @@ class MtCnFeatureNode : public RescoreNode { Fsa::StateId bptr; Fsa::StateId aid; TraceElement() - : score(Semiring::Max), bptr(Fsa::InvalidStateId), aid(Fsa::InvalidStateId) {} + : score(Semiring::Max), + bptr(Fsa::InvalidStateId), + aid(Fsa::InvalidStateId) {} }; typedef std::vector Traceback; @@ -294,7 +298,8 @@ class MtCnFeatureNode : public RescoreNode { public: MtCnFeatureNode(const std::string& name, const Core::Configuration& config) - : Precursor(name, config), alignedBestChannel_(config, "best") {} + : Precursor(name, config), + alignedBestChannel_(config, "best") {} ~MtCnFeatureNode() {} void init(const std::vector& arguments) { diff --git a/src/Flf/FlfExt/WindowedLevenshteinDistanceDecoder.cc b/src/Flf/FlfExt/WindowedLevenshteinDistanceDecoder.cc index f041b19b0..66fdae350 100644 --- a/src/Flf/FlfExt/WindowedLevenshteinDistanceDecoder.cc +++ b/src/Flf/FlfExt/WindowedLevenshteinDistanceDecoder.cc @@ -16,7 +16,9 @@ namespace Flf { // ------------------------------------------------------------------------- ConditionalPosterior::Value::Value(Fsa::LabelId label, Score condPosteriorScore, Score tuplePosteriorScore) - : label(label), condPosteriorScore(condPosteriorScore), tuplePosteriorScore(tuplePosteriorScore) {} + : label(label), + condPosteriorScore(condPosteriorScore), + tuplePosteriorScore(tuplePosteriorScore) {} class ConditionalPosterior::Internal { friend class ConditionalPosteriorBuilder; @@ -30,7 +32,9 @@ class ConditionalPosterior::Internal { Fsa::LabelId label; u32 begin, end; Node(Fsa::LabelId label, u32 begin, u32 end) - : label(label), begin(begin), end(end) {} + : label(label), + begin(begin), + end(end) {} }; typedef std::vector NodeList; @@ -85,7 +89,8 @@ class ConditionalPosterior::Internal { private: Internal(ConstLatticeRef l, u32 windowSize) - : l_(l), windowSize_(windowSize) {} + : l_(l), + windowSize_(windowSize) {} public: inline u32 windowSize() const { @@ -234,7 +239,9 @@ class ConditionalPosteriorBuilder : public Core::ReferenceCounted { u32 arcPtrIdx; ConstArcPtrList arcPtrs; Hypothesis(Node* node, Score score, u32 nReserve) - : node(node), score(score), arcPtrIdx(0) { + : node(node), + score(score), + arcPtrIdx(0) { arcPtrs.reserve(nReserve); } }; @@ -275,7 
+282,9 @@ class ConditionalPosteriorBuilder : public Core::ReferenceCounted { public: ConditionalPosteriorBuilder(u32 windowSize, bool compact) - : windowSize_(windowSize), compact_(compact), prune_(false) {} + : windowSize_(windowSize), + compact_(compact), + prune_(false) {} /* * Pruning requires a compact CN (necessary due to practical/technical issues; see below) @@ -338,7 +347,8 @@ class ConditionalPosteriorBuilder::PruningFilter : public ConditionalPosteriorBu ConstConfusionNetworkRef cn; const Fsa::LabelId invalidLabel; PruningFilter(ConstConfusionNetworkRef cn, Fsa::LabelId invalidLabel) - : cn(cn), invalidLabel(invalidLabel) {} + : cn(cn), + invalidLabel(invalidLabel) {} virtual bool keep(const Fsa::LabelId label) const { return (label == Fsa::Epsilon) ? false : true; } @@ -983,7 +993,8 @@ class ConditionalPosteriorsNode : public FilterNode { public: ConditionalPosteriorsNode(const std::string& name, const Core::Configuration& config) - : FilterNode(name, config), dumpChannel_(config, "dump") {} + : FilterNode(name, config), + dumpChannel_(config, "dump") {} virtual ~ConditionalPosteriorsNode() {} virtual void init(const std::vector& arguments) { @@ -1107,9 +1118,12 @@ class WindowedLevenshteinDistanceDecoder : public Core::ReferenceCounted { BackpointerRef backptr; BackpointerRef sideptr; Backpointer() - : hypLabel(Fsa::InvalidLabelId), refLabel(Fsa::InvalidLabelId) {} + : hypLabel(Fsa::InvalidLabelId), + refLabel(Fsa::InvalidLabelId) {} Backpointer(BackpointerRef backptr, Fsa::LabelId hypLabel, Fsa::LabelId refLabel) - : hypLabel(hypLabel), refLabel(refLabel), backptr(backptr) {} + : hypLabel(hypLabel), + refLabel(refLabel), + backptr(backptr) {} static BackpointerRef create(); static BackpointerRef extend(BackpointerRef backptr, Fsa::LabelId hypLabel, Fsa::LabelId refLabel); @@ -2087,7 +2101,9 @@ struct WindowedLevenshteinDistanceDecoder::AlignmentHelper { Fsa::StateId fromSid, toSid; BackpointerRef bptr; Arc(Fsa::StateId fromSid, Fsa::StateId toSid, BackpointerRef bptr) - : fromSid(fromSid), toSid(toSid), bptr(bptr) {} + : fromSid(fromSid), + toSid(toSid), + bptr(bptr) {} }; typedef std::vector ArcList; diff --git a/src/Flf/FlfIo.cc b/src/Flf/FlfIo.cc index 646167fe5..2ac67211a 100644 --- a/src/Flf/FlfIo.cc +++ b/src/Flf/FlfIo.cc @@ -539,7 +539,11 @@ class FlfDescriptorXmlParser : public Core::XmlSchemaParser { public: FlfDescriptorXmlParser(const Core::Configuration& config) - : Precursor(config), desc_(0), n_(0), id_(0), hasId_(-1) { + : Precursor(config), + desc_(0), + n_(0), + id_(0), + hasId_(-1) { Core::XmlMixedElement* rootElement = new Core::XmlMixedElementRelay("lattice", this, startHandler(&Self::lattice), 0, 0, XML_CHILD(new Core::XmlIgnoreElement( "head", this)), @@ -617,7 +621,8 @@ const Core::ParameterFloat FlfReader::paramScale( Semiring::DefaultScale); FlfReader::FlfReader(const Core::Configuration& config) - : Precursor(config), contextHandling_(UpdateContext) { + : Precursor(config), + contextHandling_(UpdateContext) { context_ = FlfContextRef(new FlfContext()); Core::Choice::Value contextHandling = ContextModeChoice[paramContextMode(config)]; @@ -630,7 +635,9 @@ FlfReader::FlfReader(const Core::Configuration& config) } FlfReader::FlfReader(const Core::Configuration& config, FlfContextRef context, u32 contextHandling) - : Precursor(config), contextHandling_(contextHandling), context_(context) { + : Precursor(config), + contextHandling_(contextHandling), + context_(context) { require(context_); init(); } diff --git a/src/Flf/FlfIo.hh b/src/Flf/FlfIo.hh index 
2c15a0281..b1eee2e43 100644 --- a/src/Flf/FlfIo.hh +++ b/src/Flf/FlfIo.hh @@ -65,7 +65,8 @@ struct FsaDescriptor { FsaDescriptor() {} FsaDescriptor(const std::string& format, const std::string& file) - : format(format), file(file) {} + : format(format), + file(file) {} FsaDescriptor(const std::string& qualifiedFile); FsaDescriptor(const Core::Configuration& config); FsaDescriptor& operator=(const FsaDescriptor& desc) { @@ -98,6 +99,7 @@ struct BoundariesDescriptor { BoundariesDescriptor() {} BoundariesDescriptor(const std::string& file) : file(file) {} + std::string operator()(const std::string& root = "") const; operator bool() const { return !file.empty(); diff --git a/src/Flf/FwdBwd.cc b/src/Flf/FwdBwd.cc index 3bd63b226..f6d5e1427 100644 --- a/src/Flf/FwdBwd.cc +++ b/src/Flf/FwdBwd.cc @@ -36,7 +36,8 @@ namespace Flf { typedef Core::Ref FwdBwdRef; FwdBwd::State::State() - : begin_(0), end_(0) {} + : begin_(0), + end_(0) {} FwdBwd::Arc::Arc() {} @@ -46,7 +47,12 @@ struct FwdBwd::Internal { FwdBwd::Arc* arcs; f64 min, max, sum; Internal(const ConstSemiringRefList& semirings) - : semirings(semirings), states(0), arcs(0), min(Core::Type::max), max(Core::Type::min), sum(Core::Type::min) {} + : semirings(semirings), + states(0), + arcs(0), + min(Core::Type::max), + max(Core::Type::min), + sum(Core::Type::min) {} ~Internal() { delete[] states; delete[] arcs; @@ -99,7 +105,11 @@ class FwdBwd::Builder { Core::Vector finalStateIds; Core::Ref topologicalSort; Properties() - : offset(0), nInitialArcs(0), nArcs(0), startTime(0), endTime(0) {} + : offset(0), + nInitialArcs(0), + nArcs(0), + startTime(0), + endTime(0) {} }; class TraverseLatticeProperties : protected DfsState { @@ -1030,7 +1040,8 @@ class FwdBwdBuilder::Internal : public Core::Component { struct SingleConfiguration { ConstSemiringRef semiring; FwdBwd::Parameters params; - void dump(std::ostream& os) const { + + void dump(std::ostream& os) const { if (semiring) { os << "Target semiring is \"" << semiring->name() << "\"." << std::endl; if (params.scoreId != Semiring::InvalidId) @@ -1060,7 +1071,8 @@ class FwdBwdBuilder::Internal : public Core::Component { struct CombinationConfiguration { FwdBwd::CombinationParameters params; - void dump(std::ostream& os) const { + + void dump(std::ostream& os) const { if (params.combination) { os << "Target semiring is \"" << params.combination->semiring()->name() << "\"." 
<< std::endl; if (params.scoreId != Semiring::InvalidId) @@ -1293,7 +1305,8 @@ class FwdBwdBuilderNode : public Node { public: FwdBwdBuilderNode(const std::string& name, const Core::Configuration& config) - : Node(name, config), n_(0) {} + : Node(name, config), + n_(0) {} virtual ~FwdBwdBuilderNode() {} virtual void init(const std::vector& arguments) { diff --git a/src/Flf/FwdBwd.hh b/src/Flf/FwdBwd.hh index 4d66fb477..e32eb1b8c 100644 --- a/src/Flf/FwdBwd.hh +++ b/src/Flf/FwdBwd.hh @@ -70,9 +70,11 @@ public: } typedef const Arc* const_iterator; - const_iterator begin() const { + + const_iterator begin() const { return begin_; } + const_iterator end() const { return end_; } @@ -99,7 +101,8 @@ public: f64 sum() const; const State& state(Fsa::StateId sid) const; - const Arc& arc(ConstStateRef sr, Flf::State::const_iterator a) const { + + const Arc& arc(ConstStateRef sr, Flf::State::const_iterator a) const { return *(state(sr->id()).begin() + (a - sr->begin())); } diff --git a/src/Flf/HtkSlfIo.cc b/src/Flf/HtkSlfIo.cc index acec93343..98263fdbb 100644 --- a/src/Flf/HtkSlfIo.cc +++ b/src/Flf/HtkSlfIo.cc @@ -266,7 +266,11 @@ class HtkSlfBuilder : public Core::Component { TransitState transit; Bliss::Phoneme::Id initialPhonemeId, finalPhonemeId; Node() - : sp(0), labelId(Fsa::InvalidLabelId), transit(TransitUnchecked), initialPhonemeId(InvalidPhonemeId), finalPhonemeId(InvalidPhonemeId) {} + : sp(0), + labelId(Fsa::InvalidLabelId), + transit(TransitUnchecked), + initialPhonemeId(InvalidPhonemeId), + finalPhonemeId(InvalidPhonemeId) {} }; typedef std::vector NodeList; diff --git a/src/Flf/HtkSlfIo.hh b/src/Flf/HtkSlfIo.hh index 99225a620..c1f9672e0 100644 --- a/src/Flf/HtkSlfIo.hh +++ b/src/Flf/HtkSlfIo.hh @@ -37,8 +37,10 @@ namespace Flf { * - time field is compulsory */ -typedef enum { HtkSlfForward, - HtkSlfBackward } HtkSlfType; +typedef enum { + HtkSlfForward, + HtkSlfBackward +} HtkSlfType; class HtkSlfContext; typedef Core::Ref HtkSlfContextRef; @@ -69,72 +71,88 @@ public: return silId_; } - void setType(HtkSlfType type); + void setType(HtkSlfType type); + HtkSlfType type() const { return type_; } void setFps(f32 fps); - f32 fps() const { + + f32 fps() const { return fps_; } void setCapitalize(bool isCapitalize); + bool capitalize() const { return isCapitalize_; } void setMergePenalties(bool mergePenalty); + bool mergePenalty() const { return mergePenalty_; } void setBase(f32 base); - f32 base() const { + + f32 base() const { return base_; } - void setEpsSymbol(const std::string&); + void setEpsSymbol(const std::string&); + const std::string& epsSymbol() const { return epsSymbol_; } - void setSemiring(ConstSemiringRef semiring); + void setSemiring(ConstSemiringRef semiring); + ConstSemiringRef semiring() const { return semiring_; } + ScoreId amId() const { return amId_; } + ScoreId lmId() const { return lmId_; } + ScoreId penaltyId() const { return penaltyId_; } // if silPenalty is not given or invalid, silPenalty = wrdPenalty void setPenalties(Score wrdPenalty, Score silPenalty = Semiring::Invalid); + bool hasPenalties() const { return wrdPenalty_ != Semiring::Invalid; } + Score wordPenalty() const { return wrdPenalty_; } + Score silPenalty() const { return silPenalty_; } - void setLmName(const std::string& lmName); + void setLmName(const std::string& lmName); + const std::string& lmName() const { return lmName_; } + bool cmpLm(const std::string& lmName) const { return lmName_ == lmName; } std::string info() const; - void clear(); + + void clear(); static HtkSlfContextRef create(const 
Core::Configuration& config); }; @@ -227,7 +245,8 @@ public: StateIdList htk2fsa; StateIdList fsa2htk; ConstStateRefList finals; - void clear() { + + void clear() { htk2fsa.clear(); fsa2htk.clear(); finals.clear(); diff --git a/src/Flf/IncrementalRecognizer.cc b/src/Flf/IncrementalRecognizer.cc index aee35e55f..cac73e0fe 100644 --- a/src/Flf/IncrementalRecognizer.cc +++ b/src/Flf/IncrementalRecognizer.cc @@ -53,14 +53,17 @@ namespace Flf { struct ForwardBackwardAlignment { struct Word { Word() - : amScore(Core::Type::max), lmScore(Core::Type::max) {} + : amScore(Core::Type::max), + lmScore(Core::Type::max) {} const Bliss::LemmaPronunciation* pron; u32 start, end; Fsa::StateId originState; Score amScore, lmScore; - bool intersects(const Word& rhs) const { + + bool intersects(const Word& rhs) const { return end >= rhs.start && start <= rhs.end && rhs.end >= start && rhs.start <= end; } + bool equals(const Word& rhs) const { return pron == rhs.pron && end == rhs.end && start == rhs.start; } @@ -603,7 +606,8 @@ class IncrementalRecognizer : public Speech::Recognizer { Core::Timer globalTimer_; u32 segmentFeatureCount_; - f32 globalRtf() const { + + f32 globalRtf() const { if (segmentFeatureCount_ == 0) return 0; else @@ -636,6 +640,7 @@ class IncrementalRecognizer : public Speech::Recognizer { u32 lmContextLength_; f32 relaxPruningFactor_, relaxPruningOffset_, latticeRelaxPruningFactor_, latticeRelaxPruningOffset_, adaptInitialUpdateRate_, adaptRelaxPruningFactor_, adaptRelaxPruningOffset_; + u32 latticeRelaxPruningInterval_; s32 adaptCorrectionRatio_; f32 scoreTolerance_; f32 adaptPruningFactor_; @@ -645,7 +650,6 @@ class IncrementalRecognizer : public Speech::Recognizer { bool onlyEnforceMinimumSearchSpace_; bool correctStrictInitial_; f32 maximumRtf_; - u32 latticeRelaxPruningInterval_; const Bliss::SpeechSegment* segment_; // Current sub-segment index, if partial lattices were returned by the decoder @@ -1511,11 +1515,12 @@ class IncrementalRecognizer : public Speech::Recognizer { l = persistent(l); } if (fwdBwdThreshold_ >= 0 || minArcsPerSecond_ || maxArcsPerSecond_ < Core::Type::max) { - l = pruneByFwdBwdScores(l, - fb, + l = pruneByFwdBwdScores(l, + fb, fwdBwdThreshold_ < 0 ? 
(fb->max() - fb->min()) : fwdBwdThreshold_, - minArcsPerSecond_, - maxArcsPerSecond_); + minArcsPerSecond_, + maxArcsPerSecond_); + StaticLatticeRef trimmedLattice = StaticLatticeRef(new StaticLattice); copy(l, trimmedLattice.get(), 0); trimInPlace(trimmedLattice); @@ -1558,15 +1563,16 @@ class IncrementalRecognizer : public Speech::Recognizer { mc_(mc), modelAdaptor_(SegmentwiseModelAdaptorRef(new SegmentwiseModelAdaptor(mc))), tracebackChannel_(config, "traceback"), + segmentFeatureCount_(0), lmContextLength_(paramLmContextLength(config)), relaxPruningFactor_(paramRelaxPruningFactor(config)), relaxPruningOffset_(paramRelaxPruningOffset(config)), latticeRelaxPruningFactor_(paramLatticeRelaxPruningFactor(config)), latticeRelaxPruningOffset_(paramLatticeRelaxPruningOffset(config)), - latticeRelaxPruningInterval_(paramLatticeRelaxPruningInterval(config)), adaptInitialUpdateRate_(paramAdaptInitialUpdateRate(config)), adaptRelaxPruningFactor_(paramAdaptRelaxPruningFactor(config)), adaptRelaxPruningOffset_(paramAdaptRelaxPruningOffset(config)), + latticeRelaxPruningInterval_(paramLatticeRelaxPruningInterval(config)), adaptCorrectionRatio_(paramAdaptCorrectionRatio(config)), scoreTolerance_(paramScoreTolerance(config) * mc->languageModel()->scale()), adaptPruningFactor_(paramAdaptPruningFactor(config)), @@ -1577,7 +1583,6 @@ class IncrementalRecognizer : public Speech::Recognizer { correctStrictInitial_(paramCorrectStrictInitial(config)), maximumRtf_(paramMaxRtf(config)), segment_(0), - segmentFeatureCount_(0), subSegment_(0), verboseRefinement_(paramVerboseRefinement(config)), considerSentenceBegin_(paramConsiderSentenceBegin(config)), @@ -1843,7 +1848,8 @@ class IncrementalRecognizer : public Speech::Recognizer { if (preCacheAllFrames_) { struct PreCacher : public Search::SearchAlgorithm { PreCacher() - : SearchAlgorithm(Core::Configuration()), Core::Component(Core::Configuration()) {} + : Core::Component(Core::Configuration()), + SearchAlgorithm(Core::Configuration()) {} virtual void feed(const Mm::FeatureScorer::Scorer& scorer) { dynamic_cast(scorer.get())->precache(); } @@ -1858,6 +1864,9 @@ class IncrementalRecognizer : public Speech::Recognizer { virtual bool setModelCombination(const Speech::ModelCombination& modelCombination) { return false; } + virtual bool setLanguageModel(Core::Ref) { + defect(); + } } precacher; Core::Timer timer; Speech::RecognizerDelayHandler handler(&precacher, acousticModel_, contextScorerCache_); @@ -2228,7 +2237,9 @@ class IncrementalRecognizerNode : public Node { public: IncrementalRecognizerNode(const std::string& name, const Core::Configuration& config) - : Node(name, config), mc_(), recognizer_(0) {} + : Node(name, config), + mc_(), + recognizer_(0) {} virtual ~IncrementalRecognizerNode() { delete recognizer_; } diff --git a/src/Flf/Info.cc b/src/Flf/Info.cc index d64c83a61..bde4e46d4 100644 --- a/src/Flf/Info.cc +++ b/src/Flf/Info.cc @@ -108,7 +108,15 @@ struct LatticeStatistics { u32 nFinalStates; Time minTime, maxTime; LatticeStatistics() - : nInputEpsilonArcs(0), nInputNonWordArcs(0), nInputWordArcs(0), nOutputEpsilonArcs(0), nOutputNonWordArcs(0), nOutputWordArcs(0), nFinalStates(0), minTime(Core::Type { : Precursor(f) { } StateMappedAutomaton(_ConstAutomatonRef f, const Fsa::StateMap& map) - : Precursor(f), map_(map) { + : Precursor(f), + map_(map) { } public: @@ -91,7 +92,8 @@ class MinimizeAutomaton : public StateMappedAutomaton { Fsa::StateId id_; ClassId class_; ClassEntry(Fsa::StateId id, ClassId c) - : id_(id), class_(c) { + : id_(id), + 
class_(c) { } bool operator<(const ClassEntry& e) const { return class_ < e.class_; @@ -111,7 +113,11 @@ class MinimizeAutomaton : public StateMappedAutomaton { HashTargetMappedAndLabels(_ConstAutomatonRef f, const ClassMap& classMap, const StatePotentials<_Weight>& statePotentials) - : transducer_(f->type() == Fsa::TypeTransducer), fsa_(f), classMap_(classMap), semiring_(f->semiring()), statePotentials_(statePotentials) {} + : transducer_(f->type() == Fsa::TypeTransducer), + fsa_(f), + classMap_(classMap), + semiring_(f->semiring()), + statePotentials_(statePotentials) {} size_t operator()(Fsa::StateId s) const { const _ConstStateRef sp = fsa_->getState(s); size_t key = 100003 * semiring_->hash(statePotentials_[s]) + sp->nArcs(); @@ -136,7 +142,10 @@ class MinimizeAutomaton : public StateMappedAutomaton { EqualTargetMappedAndLabels(_ConstAutomatonRef f, const ClassMap& classMap, const StatePotentials<_Weight>& statePotentials) - : fsa_(f), classMap_(classMap), semiring_(f->semiring()), statePotentials_(statePotentials) { + : fsa_(f), + classMap_(classMap), + semiring_(f->semiring()), + statePotentials_(statePotentials) { } bool operator()(Fsa::StateId as, Fsa::StateId bs) const { if (semiring_->compare(statePotentials_[as], diff --git a/src/Fsa/tOutput.cc b/src/Fsa/tOutput.cc index 2d4470e77..451dea39a 100644 --- a/src/Fsa/tOutput.cc +++ b/src/Fsa/tOutput.cc @@ -76,7 +76,10 @@ class WriteAttDfsState : public DfsState<_Automaton> { public: WriteAttDfsState(_ConstAutomatonRef f, std::ostream& o) - : Precursor(f), o_(o), input_(f->getInputAlphabet()), output_(f->getOutputAlphabet()) {} + : Precursor(f), + o_(o), + input_(f->getInputAlphabet()), + output_(f->getOutputAlphabet()) {} void discoverState(_ConstStateRef sp) { if (sp->isFinal()) { o_ << sp->id(); @@ -161,7 +164,8 @@ class WriteBinaryDfsState : public DfsState<_Automaton> { public: WriteBinaryDfsState(_ConstAutomatonRef f, Core::BinaryOutputStream& o) - : Precursor(f), o_(o) {} + : Precursor(f), + o_(o) {} void discoverState(_ConstStateRef sp) { Fsa::StateId idAndTags = sp->id() | sp->tags(); if (!(o_ << idAndTags)) @@ -256,7 +260,12 @@ class WriteLinearDfsState : public DfsState<_Automaton> { public: WriteLinearDfsState(_ConstAutomatonRef f, std::ostream& o, bool printAll = false) - : Precursor(f), first_(true), o_(o), input_(Precursor::fsa_->getInputAlphabet()), output_(Precursor::fsa_->getOutputAlphabet()), printAll_(printAll) {} + : Precursor(f), + first_(true), + o_(o), + input_(Precursor::fsa_->getInputAlphabet()), + output_(Precursor::fsa_->getOutputAlphabet()), + printAll_(printAll) {} void discoverState(_ConstStateRef sp) { for (typename _State::const_iterator a = sp->begin(); a != sp->end(); ++a) { std::string out; @@ -328,7 +337,8 @@ class WriteXmlDfsState : public DfsState<_Automaton> { public: WriteXmlDfsState(_ConstAutomatonRef f, Core::XmlWriter& o) - : Precursor(f), o_(o) {} + : Precursor(f), + o_(o) {} void discoverState(_ConstStateRef sp) { o_ << Core::XmlOpen("state") + Core::XmlAttribute("id", sp->id()); if (sp->isFinal()) @@ -433,7 +443,9 @@ class NodeTrWGWriter : public DfsState<_Automaton> { public: NodeTrWGWriter(_ConstAutomatonRef f, std::ostream& o) - : Precursor(f), o_(o), max_id_(0) {} + : Precursor(f), + o_(o), + max_id_(0) {} virtual void discoverState(_ConstStateRef sp) { if (max_id_ < sp->id()) @@ -492,7 +504,9 @@ class EdgeTrWGWriter : public DfsState<_Automaton> { public: EdgeTrWGWriter(_ConstAutomatonRef f, NodeTrWGWriter<_Automaton>& nw, std::ostream& o) - : Precursor(f), nw_(nw), o_(o) { + : 
Precursor(f), + nw_(nw), + o_(o) { inAlpha_ = f->getInputAlphabet(); } diff --git a/src/Fsa/tPrune.cc b/src/Fsa/tPrune.cc index 0958a4882..0e9920827 100644 --- a/src/Fsa/tPrune.cc +++ b/src/Fsa/tPrune.cc @@ -59,7 +59,14 @@ class PosteriorPruneAutomaton : public SlaveAutomaton<_Automaton>, public DfsSta public: PosteriorPruneAutomaton(_ConstAutomatonRef f, const _Weight& threshold, bool relative) - : Precursor(f), DfsState<_Automaton>(f), threshold_(threshold), fw_(new _StatePotentials), bw_(new _StatePotentials), forward_(*fw_), backward_(*bw_), relative_(relative) { + : Precursor(f), + DfsState<_Automaton>(f), + threshold_(threshold), + fw_(new _StatePotentials), + bw_(new _StatePotentials), + forward_(*fw_), + backward_(*bw_), + relative_(relative) { this->setProperties(Fsa::PropertyStorage | Fsa::PropertyCached, Fsa::PropertyNone); Fsa::StateId initial = f->initialStateId(); if (initial != Fsa::InvalidStateId) { @@ -70,7 +77,14 @@ class PosteriorPruneAutomaton : public SlaveAutomaton<_Automaton>, public DfsSta } PosteriorPruneAutomaton(_ConstAutomatonRef f, const _Weight& threshold, const _StatePotentials& fw, const _StatePotentials& bw, bool relative) - : Precursor(f), DfsState<_Automaton>(f), threshold_(threshold), fw_(0), bw_(0), forward_(fw), backward_(bw), relative_(relative) { + : Precursor(f), + DfsState<_Automaton>(f), + threshold_(threshold), + fw_(0), + bw_(0), + forward_(fw), + backward_(bw), + relative_(relative) { this->setProperties(Fsa::PropertyStorage | Fsa::PropertyCached, Fsa::PropertyNone); setMinWeight(threshold); } @@ -160,7 +174,9 @@ class SyncPruneAutomaton : public SlaveAutomaton<_Automaton> { public: SyncPruneAutomaton(_ConstAutomatonRef f, const _Weight& threshold) - : Precursor(f), threshold_(threshold), maxWeight_(Precursor::fsa_->semiring()->max()) { + : Precursor(f), + threshold_(threshold), + maxWeight_(Precursor::fsa_->semiring()->max()) { this->setProperties(Fsa::PropertyStorage | Fsa::PropertyCached, Fsa::PropertyNone); Fsa::StateId initial = f->initialStateId(); slice_.grow(initial, Fsa::InvalidStateId); diff --git a/src/Fsa/tRational.cc b/src/Fsa/tRational.cc index d77cb4b33..d7fda8566 100644 --- a/src/Fsa/tRational.cc +++ b/src/Fsa/tRational.cc @@ -58,7 +58,8 @@ class ClosureAutomaton : public ModifyAutomaton<_Automaton> { public: ClosureAutomaton(_ConstAutomatonRef f) - : Precursor(f), initial_(f->initialStateId()) { + : Precursor(f), + initial_(f->initialStateId()) { } /*! \todo existing loop weight should be collected with final weight. 
*/ virtual void modifyState(_State* sp) const { @@ -98,7 +99,8 @@ class KleeneClosureAutomaton : public SlaveAutomaton<_Automaton> { public: KleeneClosureAutomaton(_ConstAutomatonRef f) - : Precursor(closure(f)), initialIsFinal_(false) { + : Precursor(closure(f)), + initialIsFinal_(false) { } virtual Fsa::StateId initialStateId() const { if (Precursor::fsa_->initialStateId() != Fsa::InvalidStateId) { @@ -283,35 +285,43 @@ class ConcatUnionAutomaton : public _Automaton { } virtual std::string name() const = 0; - virtual Fsa::Type type() const { + + virtual Fsa::Type type() const { return type_; } + virtual _ConstSemiringRef semiring() const { return semiring_; } + virtual Fsa::ConstAlphabetRef getInputAlphabet() const { return input_; } + virtual Fsa::ConstAlphabetRef getOutputAlphabet() const { return output_; } + virtual void dumpState(Fsa::StateId s, std::ostream& o) const { u32 k = subAutomaton(s); o << k << ","; fsa_[k]->dumpState(subStateId(s), o); - }; + } + virtual size_t getMemoryUsed() const { size_t memoryUsed = 0; for (typename Core::Vector<_ConstAutomatonRef>::const_iterator i = fsa_.begin(); i != fsa_.end(); ++i) memoryUsed += (*i)->getMemoryUsed(); return memoryUsed; } + virtual void dumpMemoryUsage(Core::XmlWriter& o) const { o << Core::XmlOpen(name()); for (typename Core::Vector<_ConstAutomatonRef>::const_iterator i = fsa_.begin(); i != fsa_.end(); ++i) (*i)->dumpMemoryUsage(o); o << Core::XmlClose(name()); } + virtual std::string describe() const { std::string result = name() + "("; for (size_t i = 0; i < fsa_.size(); ++i) { @@ -322,11 +332,13 @@ class ConcatUnionAutomaton : public _Automaton { result += ")"; return result; } + // State ids are interleaved. We reserve 0 for the initial state. /** Index of sub-automaton. */ u32 subAutomaton(Fsa::StateId s) const { return (s - 1) % fsa_.size(); } + /** Id of state in sub-automaton. 
*/ Fsa::StateId subStateId(Fsa::StateId s) const { return (s - 1) / fsa_.size(); @@ -410,7 +422,8 @@ class UnionAutomaton : public ConcatUnionAutomaton<_Automaton> { public: UnionAutomaton(const Core::Vector<_ConstAutomatonRef>& fsa, const Core::Vector& initialWeights) - : Precursor(fsa), initialWeights_(initialWeights) { + : Precursor(fsa), + initialWeights_(initialWeights) { if (initialWeights_.empty()) initialWeights_.resize(fsa.size(), Precursor::semiring_->one()); else @@ -474,7 +487,8 @@ class UniteMapping : public Fsa::Mapping { public: UniteMapping(_ConstAutomatonRef f, u32 subAutomaton) - : uFsa_(dynamic_cast(f.get())), subAutomaton_(subAutomaton) { + : uFsa_(dynamic_cast(f.get())), + subAutomaton_(subAutomaton) { require(uFsa_); } virtual ~UniteMapping() {} @@ -709,7 +723,8 @@ class AllPrefixesAutomaton : public ModifyAutomaton<_Automaton> { typedef typename _Automaton::Arc _Arc; ProductiveStateFinder(_ConstAutomatonRef f, std::vector& productive_states) - : Precursor(f), productiveStates_(productive_states) {} + : Precursor(f), + productiveStates_(productive_states) {} virtual void discoverState(_ConstStateRef sp) { Fsa::StateTag tags = sp->tags(); @@ -792,7 +807,8 @@ class AllSuffixesAutomaton : public ModifyAutomaton<_Automaton> { typedef typename _Automaton::Arc _Arc; ReachableStateFinder(_ConstAutomatonRef f, std::vector& reachable_states) - : Precursor(f), reachableStates_(reachable_states) {} + : Precursor(f), + reachableStates_(reachable_states) {} virtual void discoverState(_ConstStateRef sp) { if (sp->id() >= reachableStates_.size()) { diff --git a/src/Fsa/tRealSemiring.hh b/src/Fsa/tRealSemiring.hh index 37304e755..c7da7a1b2 100644 --- a/src/Fsa/tRealSemiring.hh +++ b/src/Fsa/tRealSemiring.hh @@ -34,13 +34,17 @@ protected: public: LogSemiring(); LogSemiring(s32 tolerance); + virtual std::string name() const; - _Weight create() const { + + _Weight create() const { return _Weight(); } + _Weight clone(const _Weight& a) const { return _Weight(_Type(a)); } + virtual _Weight invalid() const; virtual _Weight zero() const; virtual _Weight one() const; @@ -49,9 +53,11 @@ public: virtual _Accumulator* getExtender(const _Weight& initial) const; virtual _Weight collect(const _Weight& a, const _Weight& b) const; virtual _Weight invCollect(const _Weight& a, const _Weight& b) const; - virtual bool hasInvCollect() const { + + virtual bool hasInvCollect() const { return true; } + virtual _Accumulator* getCollector(const _Weight& initial) const; virtual _Weight invert(const _Weight& a) const; virtual int compare(const _Weight& a, const _Weight& b) const; @@ -81,15 +87,19 @@ public: public: TropicalSemiring(); TropicalSemiring(s32 tolerance); + virtual std::string name() const; virtual _Weight collect(const _Weight& a, const _Weight& b) const; - virtual _Weight invCollect(const _Weight& a, const _Weight& b) const { + + virtual _Weight invCollect(const _Weight& a, const _Weight& b) const { std::cerr << "method \"invCollect\" is not supported" << std::endl; return this->invalid(); } + virtual bool hasInvCollect() const { return false; } + virtual _Accumulator* getCollector(const _Weight& initial) const; }; diff --git a/src/Fsa/tResources.hh b/src/Fsa/tResources.hh index 1bd6f5baf..5b17f7eed 100644 --- a/src/Fsa/tResources.hh +++ b/src/Fsa/tResources.hh @@ -44,9 +44,15 @@ public: Reader reader; Writer writer; Format(const std::string& name, const std::string& desc) - : name(name), desc(desc), reader(0), writer(0) {} + : name(name), + desc(desc), + reader(0), + writer(0) {} Format(const 
std::string& name, const std::string& desc, Reader reader, Writer writer) - : name(name), desc(desc), reader(reader), writer(writer) {} + : name(name), + desc(desc), + reader(reader), + writer(writer) {} }; private: diff --git a/src/Fsa/tSemiring.hh b/src/Fsa/tSemiring.hh index 7761a5e0a..6e68a942d 100644 --- a/src/Fsa/tSemiring.hh +++ b/src/Fsa/tSemiring.hh @@ -69,13 +69,16 @@ public: virtual _Weight max() const = 0; virtual _Weight extend(const _Weight& a, const _Weight& b) const = 0; virtual _Weight collect(const _Weight& a, const _Weight& b) const = 0; - virtual _Weight invCollect(const _Weight& a, const _Weight& b) const { + + virtual _Weight invCollect(const _Weight& a, const _Weight& b) const { std::cerr << "method \"invCollect\" is not supported" << std::endl; return invalid(); } + virtual bool hasInvCollect() const { return false; } + virtual _Weight invert(const _Weight& a) const = 0; virtual int compare(const _Weight& a, const _Weight& b) const = 0; virtual size_t hash(const _Weight& a) const = 0; @@ -109,7 +112,9 @@ protected: Function f; _Weight w; AnchoredAccumulator(const Self* self, Function f, const _Weight& init) - : self(self), f(f), w(init) { + : self(self), + f(f), + w(init) { self->acquireReference(); } virtual ~AnchoredAccumulator() { diff --git a/src/Fsa/tSort.cc b/src/Fsa/tSort.cc index 124aee839..f1fd01343 100644 --- a/src/Fsa/tSort.cc +++ b/src/Fsa/tSort.cc @@ -239,7 +239,9 @@ class TopologicallySortDfsState : public DfsState<_Automaton> { public: TopologicallySortDfsState(_ConstAutomatonRef f) - : Precursor(f), time_(0), isCyclic_(false) {} + : Precursor(f), + time_(0), + isCyclic_(false) {} virtual void finishState(Fsa::StateId s) { map_.grow(s, Fsa::InvalidStateId); map_[s] = time_++; diff --git a/src/Fsa/tSssp.cc b/src/Fsa/tSssp.cc index 381e85bdc..00d0032c0 100644 --- a/src/Fsa/tSssp.cc +++ b/src/Fsa/tSssp.cc @@ -35,9 +35,13 @@ class SsspQueue { public: SsspQueue() - : head_(Fsa::InvalidStateId), tail_(Fsa::InvalidStateId), n_(0) {} + : head_(Fsa::InvalidStateId), + tail_(Fsa::InvalidStateId), + n_(0) {} SsspQueue(Fsa::StateId maxStateId) - : head_(Fsa::InvalidStateId), tail_(Fsa::InvalidStateId), n_(0) { + : head_(Fsa::InvalidStateId), + tail_(Fsa::InvalidStateId), + n_(0) { next_.grow(maxStateId, Fsa::InvalidStateId); } @@ -107,7 +111,8 @@ class TopologicalSsspQueue : public SsspQueue { public: TopologicalSsspQueue(const Fsa::StateMap& s2t) - : SsspQueue(s2t.size() - 1), s2t_(s2t) {} + : SsspQueue(s2t.size() - 1), + s2t_(s2t) {} virtual ~TopologicalSsspQueue() {} @@ -419,7 +424,8 @@ class PosteriorAutomaton : public ModifyAutomaton<_Automaton> { totalInv_ = this->semiring()->invert(backwardPotentials_[f->initialStateId()]); } PosteriorAutomaton(_ConstAutomatonRef f, const _StatePotentials& forward) - : Precursor(f), forwardPotentials_(forward) { + : Precursor(f), + forwardPotentials_(forward) { backwardPotentials_ = sssp<_Automaton>(transpose<_Automaton>(f)); totalInv_ = this->semiring()->invert(backwardPotentials_[f->initialStateId()]); } diff --git a/src/Fsa/tSssp4SpecialSymbols.cc b/src/Fsa/tSssp4SpecialSymbols.cc index ae0ee8717..ac8b973d3 100644 --- a/src/Fsa/tSssp4SpecialSymbols.cc +++ b/src/Fsa/tSssp4SpecialSymbols.cc @@ -67,7 +67,7 @@ Fsa::ConstSemiring64Ref convertSemiring(Fsa::ConstSemir } template -struct byInputAndOutput : public std::binary_function<_leftArc, _rightArc, bool> { +struct byInputAndOutput { static bool cmp(const _leftArc& a, const _rightArc& b) { if (a.input() < b.input()) return true; @@ -183,10 +183,11 @@ class 
SsspBackward4SpecialSymbols : public DfsState<_Automaton> { inline static _Weight failArc(const _Arc* arc, DirectHits& directHits, _Accumulator* collector, const bool firstLevel, _SsspBackward4SpecialSymbols* owner) { Fsa::StateId ts = arc->target(); _Weight failWeight, update; + failWeight = owner->processState(ts, directHits); - update = owner->semiring_->extend( - sssp4SpecialSymbolsHelper::convertWeight<_AutomataWeight, _Weight>(arc->weight()), - failWeight); + update = owner->semiring_->extend(sssp4SpecialSymbolsHelper::convertWeight<_AutomataWeight, _Weight>(arc->weight()), + failWeight); + collector->feed(update); return failWeight; } diff --git a/src/Fsa/tStatic.hh b/src/Fsa/tStatic.hh index 3d7a3c265..c7f382f65 100644 --- a/src/Fsa/tStatic.hh +++ b/src/Fsa/tStatic.hh @@ -41,14 +41,21 @@ private: public: StaticAutomaton(Fsa::Type type = Fsa::TypeUnknown) - : Precursor(type), desc_("static"), memoryUsed_(0) {} + : Precursor(type), + desc_("static"), + memoryUsed_(0) {} StaticAutomaton(const std::string& desc, Fsa::Type type = Fsa::TypeUnknown) - : Precursor(type), desc_(desc), memoryUsed_(0) {} + : Precursor(type), + desc_(desc), + memoryUsed_(0) {} virtual ~StaticAutomaton(); + virtual void clear(); - void setDescription(const std::string& desc) { + + void setDescription(const std::string& desc) { desc_ = desc; } + virtual bool hasState(Fsa::StateId sid) const; _State* newState(Fsa::StateId tags = 0); _State* newState(Fsa::StateId tags, const _Weight& finalWeight); @@ -56,27 +63,35 @@ public: void setStateFinal(_State*, const _Weight& finalWeight); void setStateFinal(_State*); _StateRef state(Fsa::StateId s); - _State* fastState(Fsa::StateId s) { + + _State* fastState(Fsa::StateId s) { return states_[s].get(); } + const _State* fastState(Fsa::StateId s) const { return states_[s].get(); } + virtual void setState(_State* sp); virtual void deleteState(Fsa::StateId); virtual _ConstStateRef getState(Fsa::StateId s) const; virtual void normalize(); - virtual Fsa::StateId maxStateId() const { + + virtual Fsa::StateId maxStateId() const { return states_.size() - 1; } + virtual Fsa::StateId size() const { return states_.size(); } - virtual size_t getMemoryUsed() const; - virtual void dumpMemoryUsage(Core::XmlWriter& o) const; + + virtual size_t getMemoryUsed() const; + virtual void dumpMemoryUsage(Core::XmlWriter& o) const; + virtual std::string describe() const { return desc_; } + void compact(Fsa::StateMap& mapping); }; } // namespace Ftl diff --git a/src/Fsa/tStaticAlgorithm.cc b/src/Fsa/tStaticAlgorithm.cc index b2a19788f..3b70836e7 100644 --- a/src/Fsa/tStaticAlgorithm.cc +++ b/src/Fsa/tStaticAlgorithm.cc @@ -62,7 +62,8 @@ class IsInvalidArc { public: IsInvalidArc(Fsa::ConstSemiringRef sr) - : invalid_(sr->invalid()), zero_(sr->zero()) {} + : invalid_(sr->invalid()), + zero_(sr->zero()) {} bool operator()(const Fsa::Arc& a) { return (a.weight() == invalid_) or (a.weight() == zero_); } diff --git a/src/Fsa/tStorage.hh b/src/Fsa/tStorage.hh index 2cb694889..7bca234ef 100644 --- a/src/Fsa/tStorage.hh +++ b/src/Fsa/tStorage.hh @@ -40,17 +40,21 @@ protected: public: StorageAutomaton(Fsa::Type type = Fsa::TypeUnknown); - virtual Fsa::Type type() const; - virtual void setType(Fsa::Type type); - virtual void addProperties(Fsa::Property properties) const; - virtual void setProperties(Fsa::Property knownProperties, Fsa::Property properties) const; - virtual void unsetProperties(Fsa::Property unknownProperties) const; + virtual Fsa::Type type() const; + virtual void setType(Fsa::Type type); + 
virtual void addProperties(Fsa::Property properties) const; + virtual void setProperties(Fsa::Property knownProperties, Fsa::Property properties) const; + virtual void unsetProperties(Fsa::Property unknownProperties) const; + virtual _ConstSemiringRef semiring() const; - virtual void setSemiring(_ConstSemiringRef semiring) { + + virtual void setSemiring(_ConstSemiringRef semiring) { semiring_ = semiring; } + virtual Fsa::StateId initialStateId() const; - virtual void setInitialStateId(Fsa::StateId initial) { + + virtual void setInitialStateId(Fsa::StateId initial) { initial_ = initial; } diff --git a/src/Fsa/tStorageXml.hh b/src/Fsa/tStorageXml.hh index 7d3b5740a..ad1f6f422 100644 --- a/src/Fsa/tStorageXml.hh +++ b/src/Fsa/tStorageXml.hh @@ -213,7 +213,9 @@ private: public: StorageAutomatonXmlParser(const _Resources& resources, _StorageAutomaton* fsa) - : Precursor(resources.getConfiguration()), resources_(resources), fsa_(fsa) { + : Precursor(resources.getConfiguration()), + resources_(resources), + fsa_(fsa) { fsa_->setType(Fsa::TypeTransducer); Core::XmlMixedElement* arcElement = new Core::XmlMixedElementRelay("arc", this, startHandler(&Self::startArc), 0, 0, XML_CHILD(new Core::XmlMixedElementRelay("in", this, startHandler(&Self::startIn), diff --git a/src/Lattice/Accumulator.hh b/src/Lattice/Accumulator.hh index 79e768cc3..3a18490e3 100644 --- a/src/Lattice/Accumulator.hh +++ b/src/Lattice/Accumulator.hh @@ -37,10 +37,10 @@ protected: Trainer* trainer_; protected: - virtual void accumulate(Core::Ref f, Mm::MixtureIndex m, Mm::Weight w) { + virtual void accumulate(Mm::Feature::VectorRef f, Mm::MixtureIndex m, Mm::Weight w) { defect(); } - virtual void accumulate(Core::Ref f, const PosteriorsAndDensities& p) { + virtual void accumulate(Mm::Feature::VectorRef f, const PosteriorsAndDensities& p) { defect(); } /* @@ -87,9 +87,11 @@ protected: protected: virtual const Alignment* getAlignment(Fsa::ConstStateRef from, const Fsa::Arc& a); - virtual void process(TimeframeIndex t, Mm::MixtureIndex m, Mm::Weight w) { + + virtual void process(TimeframeIndex t, Mm::MixtureIndex m, Mm::Weight w) { this->accumulate((*accumulationFeatures_)[t]->mainStream(), m, w); } + virtual void reset() {} public: @@ -100,7 +102,8 @@ public: virtual ~AcousticAccumulator() {} virtual void discoverState(Fsa::ConstStateRef sp); - void setAccumulationFeatures(ConstSegmentwiseFeaturesRef accumulationFeatures) { + + void setAccumulationFeatures(ConstSegmentwiseFeaturesRef accumulationFeatures) { accumulationFeatures_ = accumulationFeatures; } }; @@ -112,7 +115,8 @@ struct Key { Speech::TimeframeIndex t; Mm::MixtureIndex m; Key(Speech::TimeframeIndex _t, Mm::MixtureIndex _m) - : t(_t), m(_m) {} + : t(_t), + m(_m) {} }; struct KeyHash { size_t operator()(const Key& k) const { diff --git a/src/Lattice/Accuracy.cc b/src/Lattice/Accuracy.cc index 5a084e7c8..01d82ecb6 100644 --- a/src/Lattice/Accuracy.cc +++ b/src/Lattice/Accuracy.cc @@ -309,7 +309,8 @@ class ApproximateAccuracyAutomaton : public ModifyWordLattice { struct TimeInterval { Speech::TimeframeIndex startTime, endTime; TimeInterval(Speech::TimeframeIndex _startTime, Speech::TimeframeIndex _endTime) - : startTime(_startTime), endTime(_endTime) {} + : startTime(_startTime), + endTime(_endTime) {} }; struct TimeIntervalHash { size_t operator()(const TimeInterval& i) const { @@ -324,7 +325,8 @@ class ApproximateAccuracyAutomaton : public ModifyWordLattice { struct Hypothesis : public TimeInterval { Fsa::LabelId label; Hypothesis(Fsa::LabelId _label, 
Speech::TimeframeIndex _startTime, Speech::TimeframeIndex _endTime) - : TimeInterval(_startTime, _endTime), label(_label) {} + : TimeInterval(_startTime, _endTime), + label(_label) {} }; typedef std::unordered_set TimeIntervals; typedef std::unordered_map States; @@ -1052,11 +1054,12 @@ class WeightedFramePhoneAccuracyAutomaton : public ApproximatePhoneAccuracyAutom class SetDerivativesDfsState : public DfsState { private: // without constant factor @param beta - struct diffSigmoid : public std::unary_function { + struct diffSigmoid { const f64 beta, marginFactor; static constexpr s64 tol = 45035996274LL; // = Core::differenceUlp(1, 1.00001) static constexpr f64 inf = 1e9; - f32 operator()(f64 x) const { + + f32 operator()(f64 x) const { require(!Core::isSignificantlyLessUlp(x, 0, tol) && !Core::isSignificantlyLessUlp(1, x, tol)); if (beta != 1) { if (Core::isAlmostEqualUlp(x, 0, tol) || Core::isAlmostEqualUlp(x, 1, tol)) { @@ -1072,8 +1075,10 @@ class WeightedFramePhoneAccuracyAutomaton : public ApproximatePhoneAccuracyAutom return 1; } } + diffSigmoid(f64 _beta, f64 _margin) - : beta(_beta), marginFactor(exp(beta * _margin)) { + : beta(_beta), + marginFactor(exp(beta * _margin)) { require(beta > 0); } }; diff --git a/src/Lattice/Archive.cc b/src/Lattice/Archive.cc index c43ddd6a0..8d5ebb542 100644 --- a/src/Lattice/Archive.cc +++ b/src/Lattice/Archive.cc @@ -43,7 +43,7 @@ enum LatticeFormat { formatSourceFile, formatWolfgang }; -} +} // namespace Lattice namespace Lattice { class FsaArchiveReader; @@ -85,7 +85,7 @@ enum Alphabet { evaluationTokenAlphabet, noLexiconCheck }; -} +} // namespace Lattice const Core::Choice Archive::alphabetChoice( "unknown", unknownAlphabet, diff --git a/src/Lattice/Basic.cc b/src/Lattice/Basic.cc index 192b476e6..1e62db9b7 100644 --- a/src/Lattice/Basic.cc +++ b/src/Lattice/Basic.cc @@ -59,7 +59,8 @@ struct Cutter { Fsa::StateId initial; Fsa::Weight finalWeight; Cutter(Fsa::StateId _initial, Fsa::Weight _finalWeight) - : initial(_initial), finalWeight(_finalWeight) {} + : initial(_initial), + finalWeight(_finalWeight) {} Fsa::ConstAutomatonRef modify(Fsa::ConstAutomatonRef fsa) { return Fsa::partial(fsa, initial, finalWeight); } @@ -165,7 +166,9 @@ class LatticeCounts : public DfsState { public: LatticeCounts(ConstWordLatticeRef lattice, Predicate p) - : DfsState(lattice), pred(p), nArcs(0) {} + : DfsState(lattice), + pred(p), + nArcs(0) {} virtual void discoverState(Fsa::ConstStateRef sp) { nArcs += std::count_if(sp->begin(), sp->end(), pred); diff --git a/src/Lattice/Best.hh b/src/Lattice/Best.hh index 8ddae5039..edcd6aca4 100644 --- a/src/Lattice/Best.hh +++ b/src/Lattice/Best.hh @@ -50,28 +50,37 @@ public: virtual ~NBestListExtractor(); ConstWordLatticeRef getNBestList(ConstWordLatticeRef); - void initialize(Bliss::LexiconRef); - void setNumberOfHypotheses(u32 nHypotheses) { + + void initialize(Bliss::LexiconRef); + + void setNumberOfHypotheses(u32 nHypotheses) { targetNHypotheses_ = nHypotheses; } + void setMinPruningThreshold(f32 minThreshold) { minThreshold_ = minThreshold; } + void setMaxPruningThreshold(f32 maxThreshold) { maxThreshold_ = maxThreshold; } + void setPruningIncrement(f32 thresholdIncrement) { thresholdIncrement_ = thresholdIncrement; } + void setWorkOnOutput(bool workOnOutput) { workOnOutput_ = workOnOutput; } + void setLatticeIsDeterministic(bool isDeterministic) { latticeIsDeterministic_ = isDeterministic; } + void setHasFailArcs(bool hasFailArcs) { hasFailArcs_ = hasFailArcs; } + void setNormalize(bool normalize) { normalize_ = 
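
The diffSigmoid functor above (like the ByInputAndTarget comparator in Lattice.cc later in this patch) no longer derives from std::unary_function / std::binary_function; those adaptor bases were deprecated in C++11 and removed in C++17, and a function object only needs a callable operator(). A minimal before/after sketch on an unrelated functor (names hypothetical):

    #include <algorithm>
    #include <vector>

    // Before (pre-C++17):  struct Square : std::unary_function<double, double> { ... };
    // After: no base class; the nested typedefs can be spelled out if legacy code needs them.
    struct Square {
        using argument_type = double;   // optional, only for old adaptor-style code
        using result_type   = double;
        double operator()(double x) const { return x * x; }
    };

    int main() {
        std::vector<double> v{1.0, 2.0, 3.0};
        std::transform(v.begin(), v.end(), v.begin(), Square{});
        return v.size() == 3 ? 0 : 1;
    }
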
normalize; } diff --git a/src/Lattice/Lattice.cc b/src/Lattice/Lattice.cc index 4b49d9ba9..65f86b903 100644 --- a/src/Lattice/Lattice.cc +++ b/src/Lattice/Lattice.cc @@ -325,7 +325,7 @@ class TimeConditionedWordLattice : public Fsa::SlaveAutomaton, */ mutable Core::Ref timeConditionedWordBoundaries_; - struct ByInputAndTarget : public std::binary_function { + struct ByInputAndTarget { bool operator()(const Fsa::Arc& a, const Fsa::Arc& b) const { return ((a.input() < b.input()) || ((a.input() == b.input()) && (a.target() < b.target()))); } diff --git a/src/Lattice/Lattice.hh b/src/Lattice/Lattice.hh index 99f2f4854..299bc9a10 100644 --- a/src/Lattice/Lattice.hh +++ b/src/Lattice/Lattice.hh @@ -35,9 +35,11 @@ public: : final(Bliss::Phoneme::term), initial(Bliss::Phoneme::term) {} Transit(const std::pair& transit) - : final(transit.first), initial(transit.second) {} + : final(transit.first), + initial(transit.second) {} Transit(Bliss::Phoneme::Id _final, Bliss::Phoneme::Id _initial) - : final(_final), initial(_initial) {} + : final(_final), + initial(_initial) {} bool operator==(const Transit& rhs) const { return final == rhs.final && initial == rhs.initial; @@ -54,10 +56,12 @@ public: WordBoundary(Speech::TimeframeIndex time = Speech::InvalidTimeframeIndex) : time_(time) {} WordBoundary(Speech::TimeframeIndex time, const Transit& transit) - : time_(time), transit_(transit) {} + : time_(time), + transit_(transit) {} WordBoundary(Speech::TimeframeIndex time, const std::pair& transit) - : time_(time), transit_(transit) {} + : time_(time), + transit_(transit) {} void setTime(Speech::TimeframeIndex time) { time_ = time; @@ -244,17 +248,22 @@ public: void setWordBoundaries(Core::Ref wordBoundaries) { wordBoundaries_ = wordBoundaries; } + Core::Ref wordBoundaries() const { return wordBoundaries_; } + const WordBoundary& wordBoundary(Fsa::StateId id) const { return (*wordBoundaries_)[id]; } + Speech::TimeframeIndex time(Fsa::StateId id) const { return (*wordBoundaries_)[id].time(); } + Speech::TimeframeIndex maximumTime() const; - bool hasPart(const std::string& name) const { + + bool hasPart(const std::string& name) const { return parts_[name] != Core::Choice::IllegalValue; } }; diff --git a/src/Lattice/Makefile b/src/Lattice/Makefile index 569258d84..fb1a54fee 100644 --- a/src/Lattice/Makefile +++ b/src/Lattice/Makefile @@ -7,8 +7,7 @@ include $(TOPDIR)/Makefile.cfg # ----------------------------------------------------------------------------- SUBDIRS = -#TARGETS = libSprintLattice.$(a) check$(exe) -TARGETS = libSprintLattice.$(a) +TARGETS = libSprintLattice.$(a) check$(exe) LIBSPRINTLATTICE_O = \ $(OBJDIR)/Archive.o \ @@ -63,7 +62,20 @@ LIBSPRINTLATTICE_O += $(OBJDIR)/HtkWriter.o endif ifdef MODULE_MATH_NR -CHECK_O += ../Math/Nr/libSprintMathNr$(a) +CHECK_O += ../Math/Nr/libSprintMathNr.$(a) +endif + +ifdef MODULE_NN_SEQUENCE_TRAINING +CHECK_O += ../Nn/libSprintNn.$(a) +endif +ifdef MODULE_PYTHON +CHECK_O += ../Python/libSprintPython.$(a) +endif + +ifdef MODULE_TENSORFLOW +CHECK_O += ../Tensorflow/libSprintTensorflow.$(a) +CXXFLAGS += $(TF_CXXFLAGS) +LDFLAGS += $(TF_LDFLAGS) endif # ----------------------------------------------------------------------------- diff --git a/src/Lattice/Merge.cc b/src/Lattice/Merge.cc index 5e08ac483..728207f3c 100644 --- a/src/Lattice/Merge.cc +++ b/src/Lattice/Merge.cc @@ -67,14 +67,17 @@ struct ContextEquality { struct ArcWithContext : public Fsa::Arc { Context rightContext_; - bool operator==(const ArcWithContext& rhs) const { + + bool operator==(const 
ArcWithContext& rhs) const { return input_ == rhs.input_ && output_ == rhs.output_ && rightContext_ == rhs.rightContext_; } + ArcWithContext(Fsa::StateId target, Fsa::Weight weight, Fsa::LabelId input, Fsa::LabelId output, const Context& rightContext) : Fsa::Arc(target, weight, input, output), rightContext_(rightContext) {} + const Context& rightContext() const { return rightContext_; } @@ -235,10 +238,13 @@ class MergeWordLattice : public Fsa::SlaveAutomaton { virtual Fsa::StateId initialStateId() const { return 0; } + virtual Fsa::ConstStateRef getState(Fsa::StateId s) const; - virtual std::string describe() const { + + virtual std::string describe() const { return Core::form("merge(%s)", fsa_->describe().c_str()); } + Core::Ref wordBoundaries() const { return wordBoundaries_; } @@ -345,7 +351,9 @@ class TurnOffCompetingHypothesesLattice : public ModifyWordLattice { InitializeCorrectHypothesesDfsState( Fsa::ConstAutomatonRef fsa, Fsa::ConstMappingRef mapping, CorrectHypotheses& hypotheses) - : Fsa::DfsState(fsa), mapping_(mapping), hypotheses_(hypotheses) {} + : Fsa::DfsState(fsa), + mapping_(mapping), + hypotheses_(hypotheses) {} virtual ~InitializeCorrectHypothesesDfsState() {} diff --git a/src/Lattice/TimeframeError.cc b/src/Lattice/TimeframeError.cc index 128dab24a..4b7628685 100644 --- a/src/Lattice/TimeframeError.cc +++ b/src/Lattice/TimeframeError.cc @@ -54,7 +54,8 @@ class LemmaMapping : public Mapping { public: LemmaMapping(const ShortPauses& shortPauses, Core::Ref alphabet) - : Mapping(shortPauses), alphabet_(alphabet) {} + : Mapping(shortPauses), + alphabet_(alphabet) {} virtual ~LemmaMapping() {} virtual Fsa::LabelId map(Fsa::LabelId pronId) const { diff --git a/src/Lattice/WeightedAccumulator.hh b/src/Lattice/WeightedAccumulator.hh index 627f9d4f6..da0611ac5 100644 --- a/src/Lattice/WeightedAccumulator.hh +++ b/src/Lattice/WeightedAccumulator.hh @@ -93,8 +93,8 @@ protected: Core::Ref posteriorFeatureScorer_; protected: - virtual void accumulate(Core::Ref f, Mm::MixtureIndex m, Mm::Weight w) {} - virtual void accumulate(Core::Ref f, const PosteriorsAndDensities& p) {} + virtual void accumulate(Mm::Feature::VectorRef f, Mm::MixtureIndex m, Mm::Weight w) {} + virtual void accumulate(Mm::Feature::VectorRef f, const PosteriorsAndDensities& p) {} virtual void accumulate(Core::Ref sf, Mm::MixtureIndex m, Mm::Weight w) {} public: @@ -102,8 +102,10 @@ public: typename Precursor::AlignmentGeneratorRef, Trainer*, Mm::Weight, Core::Ref); virtual ~DensityCachedAcousticAccumulator() {} + virtual void finish(); - void setFeatureScorer(Core::Ref fs) { + + void setFeatureScorer(Core::Ref fs) { posteriorFeatureScorer_ = fs; } }; @@ -130,8 +132,10 @@ public: Trainer*, Mm::Weight, Core::Ref); virtual ~TdpAccumulator() {} + virtual void discoverState(Fsa::ConstStateRef sp); - void setTransitionFeatures(Core::Ref transitions) { + + void setTransitionFeatures(Core::Ref transitions) { transitions_ = transitions; } }; @@ -156,10 +160,13 @@ protected: public: LmAccumulator(Trainer*, Mm::Weight, Core::Ref); virtual ~LmAccumulator() {} + virtual void discoverState(Fsa::ConstStateRef sp); - void setMgramFeatures(Core::Ref mgrams) { + + void setMgramFeatures(Core::Ref mgrams) { mgrams_ = mgrams; } + virtual void setFsa(Fsa::ConstAutomatonRef); }; diff --git a/src/Lattice/check.cc b/src/Lattice/check.cc index 4f95f7a31..1d2959223 100644 --- a/src/Lattice/check.cc +++ b/src/Lattice/check.cc @@ -52,7 +52,7 @@ class TestApplication : public Application { continue; if (it.name() == 
Lattice::Archive::latticeConfigFilename) continue; - if (it.name() == Fsa::ArchiveReader::alphabetFilename) + if (it.name() == Fsa::Archive::paramAlphabetFilename(config)) continue; log("read \"%s\"", it.name().c_str()); Fsa::ConstAutomatonRef f = @@ -68,6 +68,6 @@ class TestApplication : public Application { delete archiveReader; return 0; } -} app; // <- You have to create ONE instance of the application +}; -APPLICATION +APPLICATION(TestApplication) diff --git a/src/Lm/AbstractStateManager.hh b/src/Lm/AbstractStateManager.hh new file mode 100644 index 000000000..d64dd0ab2 --- /dev/null +++ b/src/Lm/AbstractStateManager.hh @@ -0,0 +1,71 @@ +#ifndef _LM_ABSTRACT_STATE_MANAGER_HH +#define _LM_ABSTRACT_STATE_MANAGER_HH + +#include +#include +#include + +#include +#ifdef MODULE_LM_ONNX +#include +#endif +#ifdef MODULE_LM_TFRNN +#include +#include +#endif + +#include "CompressedVector.hh" +#ifdef MODULE_LM_ONNX +#include +#endif + +namespace Lm { + +template +class AbstractStateManager : public Core::Component { +public: + using Precursor = Core::Component; + using FeedDict = std::vector>; + using TargetList = std::vector; + using StateVariables = std::vector; + using HistoryState = std::vector>; + + AbstractStateManager(Core::Configuration const& config); + virtual ~AbstractStateManager() = default; + + virtual bool requiresAllParentStates() const; + + virtual HistoryState initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory) = 0; + + virtual void mergeStates(StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + FeedDict& feed_dict, + TargetList& targets) = 0; + + virtual std::vector splitStates(StateVariables const& vars, + std::vector& suffix_lengths, + std::vector const& state_tensors, + CompressedVectorFactory const& vector_factory) = 0; +}; + +#ifdef MODULE_LM_ONNX +using OnnxStateManager = AbstractStateManager; +#endif +#ifdef MODULE_LM_TFRNN +using TFStateManager = AbstractStateManager; +#endif + +template +inline bool AbstractStateManager::requiresAllParentStates() const { + return false; +} + +template +inline AbstractStateManager::AbstractStateManager(Core::Configuration const& config) + : Precursor(config) { +} + +} // namespace Lm + +#endif // _LM_ABSTRACT_STATE_MANAGER_HH diff --git a/src/Lm/ClassLm.cc b/src/Lm/ClassLm.cc index cd65d5b6c..93b425c60 100644 --- a/src/Lm/ClassLm.cc +++ b/src/Lm/ClassLm.cc @@ -39,12 +39,14 @@ class ClassTokenAlphabet : public Bliss::TokenAlphabet { public: ClassTokenAlphabet(ConstClassMappingRef mapping) - : Bliss::TokenAlphabet(mapping->tokenInventory()), mapping_(mapping) {} + : Bliss::TokenAlphabet(mapping->tokenInventory()), + mapping_(mapping) {} virtual ~ClassTokenAlphabet() {} }; ClassMapping::ClassMapping(const Core::Configuration& config, Bliss::LexiconRef lexicon) - : Core::Component(config), lexicon_(lexicon) { + : Core::Component(config), + lexicon_(lexicon) { tokenAlphabet_ = Fsa::ConstAlphabetRef(new ClassTokenAlphabet(Core::ref(this))); } diff --git a/src/Lm/CombineLm.cc b/src/Lm/CombineLm.cc index ef43e0868..647cca944 100644 --- a/src/Lm/CombineLm.cc +++ b/src/Lm/CombineLm.cc @@ -22,7 +22,8 @@ namespace { class CombineHistoryManager : public Lm::HistoryManager { public: CombineHistoryManager(size_t numLms) - : Lm::HistoryManager(), numLms_(numLms) { + : Lm::HistoryManager(), + numLms_(numLms) { } virtual ~CombineHistoryManager() = default; @@ -90,7 +91,14 @@ Core::ParameterFloat CombineLanguageModel::paramSkipThreshold( "skip-threshold", "if this LM's 
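
The new AbstractStateManager.hh above factors the recurrent-LM state handling into a backend-agnostic template, instantiated as OnnxStateManager and TFStateManager. Its contract: build an initial (compressed) hidden state, pack the states of all batched history prefixes into backend tensors before a forward pass (mergeStates), and unpack the resulting state tensors into per-history compressed vectors afterwards (splitStates). A much-reduced sketch of that contract, with plain float vectors standing in for CompressedVectorPtr and hypothetical Value/StateVariable parameters standing in for the backend types:

    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    template<typename Value, typename StateVariable>
    struct StateManagerContract {
        // one (compressed) vector per recurrent state variable of the network
        using HistoryState = std::vector<std::shared_ptr<std::vector<float>>>;
        using FeedDict     = std::vector<std::pair<std::string, Value>>;
        using TargetList   = std::vector<std::string>;

        virtual ~StateManagerContract() = default;

        // zero state for a fresh history
        virtual HistoryState initialState(std::vector<StateVariable> const& vars) = 0;

        // pack the states of all batched history prefixes into backend tensors / fetch targets
        virtual void mergeStates(std::vector<StateVariable> const& vars,
                                 std::vector<HistoryState const*> const& prefixStates,
                                 FeedDict& feedDict, TargetList& targets) = 0;

        // unpack the backend's output state tensors into one HistoryState per hypothesis
        virtual std::vector<HistoryState> splitStates(std::vector<StateVariable> const& vars,
                                                      std::vector<Value> const& stateTensors) = 0;
    };
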
(unscaled) score is greater than this threshold successive LMs are not evaluated", std::numeric_limits::max()); CombineLanguageModel::CombineLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l) - : Core::Component(c), CombineLanguageModel::Precursor(c, l), lms_(), unscaled_lms_(), linear_combination_(paramLinearCombination(c)), lookahead_lm_(paramLookaheadLM(config)), recombination_lm_(paramRecombinationLM(config)) { + : Core::Component(c), + CombineLanguageModel::Precursor(c, l), + lms_(), + unscaled_lms_(), + linear_combination_(paramLinearCombination(c)), + lookahead_lm_(paramLookaheadLM(config)), + recombination_lm_(paramRecombinationLM(config)), + staticRequestSize_(0) { size_t num_lms = paramNumLms(c); for (size_t i = 0ul; i < num_lms; i++) { Core::Configuration sub_config = select(std::string("lm-") + std::to_string(i + 1)); @@ -98,6 +106,7 @@ CombineLanguageModel::CombineLanguageModel(Core::Configuration const& c, Bliss:: unscaled_lms_.push_back(lms_.back()->unscaled()); ssa_lms_.push_back(dynamic_cast(unscaled_lms_.back().get())); skip_thresholds_.push_back(paramSkipThreshold(sub_config)); + lmIds_.push_back(i); } historyManager_ = new CombineHistoryManager(num_lms); } @@ -162,48 +171,37 @@ History CombineLanguageModel::reducedHistory(History const& history, u32 limit) return h; } -Score CombineLanguageModel::score(History const& history, Token w) const { +History CombineLanguageModel::reduceHistoryByN(History const& history, u32 n) const { + require(history.isManagedBy(historyManager_)); + History const* prev_hist = reinterpret_cast(history.handle()); + History* new_hist = new History[lms_.size()]; + for (size_t i = 0ul; i < lms_.size(); i++) { + new_hist[i] = lms_[i]->reduceHistoryByN(prev_hist[i], n); + } + History h = this->history(new_hist); + delete[] new_hist; + return h; +} + +std::string CombineLanguageModel::formatHistory(const History& h) const { + const History* hist = reinterpret_cast(h.handle()); + + std::stringstream ss; + ss << "CombinedHistory<"; + for (size_t i = 0ul; i < lms_.size(); i++) { + ss << " h" << i << ": " << unscaled_lms_[i]->formatHistory(hist[i]); + } + ss << " >"; + return ss.str(); +} + +Score CombineLanguageModel::score(const History& history, Token w) const { require(history.isManagedBy(historyManager_)); - History const* hist = reinterpret_cast(history.handle()); - Score prev_score = 0.0; - bool override_score = false; if (linear_combination_) { - Score s(std::numeric_limits::infinity()); - for (size_t i = 0ul; i < lms_.size(); i++) { - Score raw_score = 0.0; - if (not override_score) { - raw_score = unscaled_lms_[i]->score(hist[i], w); - prev_score = raw_score; - override_score |= raw_score >= skip_thresholds_[i]; - } - else { - raw_score = prev_score; - if (unscaled_lms_[i]->scoreCached(history, w)) { - raw_score = unscaled_lms_[i]->score(hist[i], w); - } - } - s = Math::scoreSum(s, raw_score - std::log(lms_[i]->scale())); - } - return s; + return score_(history, w, lmIds_); } else { - Score s(0.0); - for (size_t i = 0ul; i < lms_.size(); i++) { - Score raw_score = unscaled_lms_[i]->score(hist[i], w); - if (not override_score) { - raw_score = unscaled_lms_[i]->score(hist[i], w); - prev_score = raw_score; - override_score |= raw_score >= skip_thresholds_[i]; - } - else { - raw_score = prev_score; - if (unscaled_lms_[i]->scoreCached(history, w)) { - raw_score = unscaled_lms_[i]->score(hist[i], w); - } - } - s += raw_score * lms_[i]->scale(); - } - return s; + return score_(history, w, lmIds_); } } @@ -226,6 +224,127 @@ Score 
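
score() now delegates both branches to the templated score_() over the full LM index list; the combination rules themselves are unchanged from the removed inline code. With scores kept as negative log-probabilities and assuming Math::scoreSum is a log-space sum, the linear branch interpolates probabilities with the per-LM scale as mixture weight, while the log-linear branch is a scale-weighted sum of scores. A self-contained sketch of the two rules (hypothetical helper names):

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    // scoreSum(a, b) = -log(exp(-a) + exp(-b)), evaluated stably (assumed to match Math::scoreSum).
    double scoreSum(double a, double b) {
        double m = std::min(a, b);
        if (std::isinf(m))
            return m;
        return m - std::log(std::exp(m - a) + std::exp(m - b));
    }

    // Linear interpolation of probabilities: p = sum_i lambda_i * p_i, score = -log p.
    double combineLinear(std::vector<double> const& scores, std::vector<double> const& lambda) {
        double s = std::numeric_limits<double>::infinity();
        for (size_t i = 0; i < scores.size(); ++i)
            s = scoreSum(s, scores[i] - std::log(lambda[i]));
        return s;
    }

    // Log-linear combination: scale-weighted sum of negative log-probabilities.
    double combineLogLinear(std::vector<double> const& scores, std::vector<double> const& scale) {
        double s = 0.0;
        for (size_t i = 0; i < scores.size(); ++i)
            s += scores[i] * scale[i];
        return s;
    }
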
CombineLanguageModel::sentenceEndScore(const History& history) const { } } +void CombineLanguageModel::getBatch(const History& h, const CompiledBatchRequest* cbr, std::vector& result) const { + if (cacheHist_.empty() || cacheScores_.empty() || !matchCacheHistory(h)) { + Precursor::getBatch(h, cbr, result); + return; + } + + // apply update on partial sparse LMs' tokens only, others are cached and operated in same scheme + require(h.isManagedBy(historyManager_)); + const History* hist = reinterpret_cast(h.handle()); + const NonCompiledBatchRequest* ncbr = required_cast(const NonCompiledBatchRequest*, cbr); + const BatchRequest& request = ncbr->request; + + std::unordered_set tokens; + Score backoff = 0; + if (linear_combination_) { + backoff = std::numeric_limits::infinity(); + } + for (u32 i = 0; i < lms_.size(); ++i) { + if (cacheHist_[i].isValid()) { + continue; + } + HistorySuccessors subSuccessors = unscaled_lms_[i]->getHistorySuccessors(hist[i]); + for (const WordScore& ws : subSuccessors) { + tokens.insert(ws.token()); + } + if (linear_combination_) { + backoff = Math::scoreSum(backoff, subSuccessors.backOffScore - std::log(lms_[i]->scale())); + } + else { + backoff += subSuccessors.backOffScore * lms_[i]->scale(); + } + } + + // non-existing tokens' scores based on cached scores and backoff + verify(result.size() == cacheScores_.size()); + if (linear_combination_) { + result = cacheScores_; // assume 0-prob. here + } + else { + std::transform(cacheScores_.begin(), cacheScores_.end(), result.begin(), std::bind(std::plus(), std::placeholders::_1, backoff * ncbr->scale())); + } + + // full combined score for these existing tokens (Note: further simplified to first token only) + for (std::unordered_set::const_iterator tokId = tokens.begin(); tokId != tokens.end(); ++tokId) { + std::vector& rqsts = token2Requests_.at(*tokId); + Score tokScore = score(h, request[rqsts.front()].tokens[0]) * ncbr->scale(); + for (std::vector::const_iterator reqId = rqsts.begin(); reqId != rqsts.end(); ++reqId) { + const Request& r = request[*reqId]; + Score sco = tokScore + r.offset; + if (result[r.target] > sco) { + result[r.target] = sco; + } + } + } +} + +void CombineLanguageModel::cacheBatch(const History& h, const CompiledBatchRequest* cbr, u32 size) const { + verify(h.isValid()); + if (linear_combination_) { + cacheBatch_(h, cbr, size); + } + else { + cacheBatch_(h, cbr, size); + } +} + +bool CombineLanguageModel::fixedHistory(s32 limit) const { + for (u32 i = 0; i < lms_.size(); ++i) { + if (!unscaled_lms_[i]->fixedHistory(limit)) { + return false; + } + } + return true; +} + +bool CombineLanguageModel::isSparse(const History& h) const { + // combineLM itself is used for lookahead: only true if all subLMs are sparse + if (!h.isValid()) { + for (u32 i = 0; i < lms_.size(); ++i) { + if (!lms_[i]->isSparse(h)) { + return false; + } + } + return true; + } + + require(h.isManagedBy(historyManager_)); + const History* hist = reinterpret_cast(h.handle()); + for (u32 i = 0; i < lms_.size(); ++i) { + if (!lms_[i]->isSparse(hist[i])) { + return false; + } + } + return true; +} + +HistorySuccessors CombineLanguageModel::getHistorySuccessors(const History& h) const { + if (linear_combination_) { + return getCombinedHistorySuccessors(h); + } + else { + return getCombinedHistorySuccessors(h); + } +} + +Score CombineLanguageModel::getBackOffScore(const History& h) const { + require(h.isManagedBy(historyManager_)); + const History* hist = reinterpret_cast(h.handle()); + Score backoff = linear_combination_ ? 
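
The new getBatch()/cacheBatch() pair implements a partially cached lookahead: the contributions of the non-sparse sub-LMs are precomputed once into cacheScores_, and per history only the tokens explicitly listed by some sparse sub-LM are rescored with the full combined model, while all other request targets fall back to the cached scores plus the sparse sub-LMs' combined back-off. A simplified sketch of the final fill-in step (request offsets and the batch-request scale are ignored; the token-to-target mapping and names are hypothetical):

    #include <algorithm>
    #include <utility>
    #include <vector>

    std::vector<double> lookaheadScores(
            std::vector<double> const& cachedWithBackoff,                  // cacheScores_ (+ combined back-off)
            std::vector<std::pair<int, double>> const& rescoredTokens,     // token id -> full combined score
            std::vector<std::vector<int>> const& tokenToTargets) {         // token id -> request targets
        std::vector<double> result = cachedWithBackoff;                    // default for unseen tokens
        for (auto const& [token, score] : rescoredTokens)
            for (int target : tokenToTargets[token])
                result[target] = std::min(result[target], score);          // keep the better (lower) score
        return result;
    }
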
std::numeric_limits::infinity() : 0; + + for (u32 i = 0; i < lms_.size(); ++i) { + if (linear_combination_) { + backoff = Math::scoreSum(backoff, unscaled_lms_[i]->getBackOffScore(hist[i]) - std::log(lms_[i]->scale())); + } + else { + backoff += unscaled_lms_[i]->getBackOffScore(hist[i]) * lms_[i]->scale(); + } + } + return backoff; +} + Core::Ref CombineLanguageModel::lookaheadLanguageModel() const { if (lookahead_lm_ > 0) { require_le(static_cast(lookahead_lm_), unscaled_lms_.size()); @@ -265,4 +384,202 @@ void CombineLanguageModel::setInfo(History const& hist, SearchSpaceInformation c } } +// combine sparse scores with closest behavior as actual scoring +// one HistorySuccessors for each subLM where tokens are not requested to be the same +// combined HistorySuccessors is a super-set of all sub HistorySuccessors with score combined +// in the same way as scoring (use backoff score if a token does not exist) +// TODO test efficiency and maybe improve +template +HistorySuccessors CombineLanguageModel::getCombinedHistorySuccessors(const History& h) const { + require(h.isManagedBy(historyManager_)); + const History* hist = reinterpret_cast(h.handle()); + + TokenScoreMap combineSuccessors; + std::set combineTokens; + Score backoff = 0; + if (linear) { + backoff = std::numeric_limits::infinity(); + } + + for (u32 i = 0; i < lms_.size(); ++i) { + HistorySuccessors subSuccessors = unscaled_lms_[i]->getHistorySuccessors(hist[i]); + std::set subTokens; + for (const WordScore& ws : subSuccessors) { + subTokens.insert(ws.token()); + TokenScoreMap::iterator iter = combineSuccessors.insert(std::make_pair(ws.token(), backoff)).first; + if (linear) { + iter->second = Math::scoreSum(iter->second, ws.score() - std::log(lms_[i]->scale())); + } + else { + iter->second += ws.score() * lms_[i]->scale(); + } + } + + if (combineTokens.empty()) { + combineTokens.swap(subTokens); + } + else if (subTokens.empty()) { + for (TokenScoreMap::iterator iter = combineSuccessors.begin(); iter != combineSuccessors.end(); ++iter) { + if (linear) { + iter->second = Math::scoreSum(iter->second, subSuccessors.backOffScore - std::log(lms_[i]->scale())); + } + else { + iter->second += subSuccessors.backOffScore * lms_[i]->scale(); + } + } + } + else { + std::set missTokens; + std::set_difference(combineTokens.begin(), combineTokens.end(), subTokens.begin(), subTokens.end(), std::inserter(missTokens, missTokens.begin())); + for (std::set::const_iterator it = missTokens.begin(); it != missTokens.end(); ++it) { + Score& s = combineSuccessors[*it]; + if (linear) { + s = Math::scoreSum(s, subSuccessors.backOffScore - std::log(lms_[i]->scale())); + } + else { + s += subSuccessors.backOffScore * lms_[i]->scale(); + } + } + if (subTokens.size() > missTokens.size()) { + subTokens.insert(missTokens.begin(), missTokens.end()); + combineTokens.swap(subTokens); + } + else { + missTokens.insert(subTokens.begin(), subTokens.end()); + combineTokens.swap(missTokens); + } + } + + if (linear) { + backoff = Math::scoreSum(backoff, subSuccessors.backOffScore - std::log(lms_[i]->scale())); + } + else { + backoff += subSuccessors.backOffScore * lms_[i]->scale(); + } + } + + HistorySuccessors res; + res.backOffScore = backoff; + res.reserve(combineSuccessors.size()); + for (TokenScoreMap::const_iterator iter = combineSuccessors.begin(); iter != combineSuccessors.end(); ++iter) { + res.emplace_back(iter->first, iter->second); + } + return res; +} + +template +Score CombineLanguageModel::score_(const History& history, Token w, const std::vector& lmIds) 
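
getCombinedHistorySuccessors() builds the union of all sub-LMs' sparse successor lists: a token listed by any sub-LM appears in the result, and every sub-LM that does not list it contributes its back-off score instead, so the combined list mimics the regular scoring path. A compact sketch of that merge for the log-linear branch (the linear branch replaces the additions with scoreSum; names hypothetical):

    #include <unordered_map>

    using Successors = std::unordered_map<int, double>;   // token id -> negative log score

    // Fold one sub-LM into the running combination. combinedBackoff holds the accumulated
    // back-off of all sub-LMs folded so far (used to initialize tokens this sub-LM introduces).
    void foldSubLm(Successors& combined, double& combinedBackoff,
                   Successors const& sub, double subBackoff, double scale) {
        for (auto& [token, score] : combined)                              // tokens seen before
            score += scale * (sub.count(token) ? sub.at(token) : subBackoff);
        for (auto const& [token, score] : sub)                             // tokens new in this sub-LM
            if (!combined.count(token))
                combined[token] = combinedBackoff + scale * score;
        combinedBackoff += scale * subBackoff;
    }
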
const { + History const* hist = reinterpret_cast(history.handle()); + + Score prev_score = 0.0; + bool override_score = false; + Score comb_score = linear ? std::numeric_limits::infinity() : 0.0; + + for (std::vector::const_iterator it = lmIds.begin(); it != lmIds.end(); ++it) { + Score raw_score = 0.0; + if (!override_score) { + raw_score = unscaled_lms_[*it]->score(hist[*it], w); + prev_score = raw_score; + override_score |= raw_score >= skip_thresholds_[*it]; + } + else { + raw_score = prev_score; + if (unscaled_lms_[*it]->scoreCached(history, w)) { + raw_score = unscaled_lms_[*it]->score(hist[*it], w); + } + } + if (linear) { + comb_score = Math::scoreSum(comb_score, raw_score - std::log(lms_[*it]->scale())); + } + else { + comb_score += raw_score * lms_[*it]->scale(); + } + } + return comb_score; +} + +template +void CombineLanguageModel::cacheBatch_(const History& h, const CompiledBatchRequest* cbr, u32 size) const { + cacheHist_.clear(); + cacheScores_.clear(); + verify(matchCacheHistory(h)); + // partial non-sparse LMs to be cached + std::vector cacheLmIds; + for (u32 i = 0; i < lms_.size(); ++i) { + if (cacheHist_[i].isValid()) { + cacheLmIds.push_back(i); + } + } + if (cacheLmIds.empty() || cacheLmIds.size() == lms_.size()) { + cacheHist_.clear(); + return; + } + + // cached LMs combined scoring + token to request mapping + const NonCompiledBatchRequest* ncbr = required_cast(const NonCompiledBatchRequest*, cbr); + const BatchRequest& request = ncbr->request; + cacheScores_.resize(size, Core::Type::max); + + u32 startIdx = 0; + if (token2Requests_.empty() && staticToken2Requests_.empty()) { + staticRequestSize_ = request.size(); + } + else if (!staticToken2Requests_.empty()) { + verify(staticRequestSize_ > 0 && request.size() >= staticRequestSize_); + token2Requests_ = staticToken2Requests_; + startIdx = staticRequestSize_; + } + token2Requests_.resize(lexicon()->nSyntacticTokens()); + + for (u32 idx = 0; idx < request.size(); ++idx) { + const Request& r = request[idx]; + Score sco = 0.0; + if (r.tokens.length() >= 1) { + // first token only: mostly should be just single mapping + if (idx >= startIdx) { + token2Requests_.at(r.tokens[0]->id()).push_back(idx); + } + sco += score_(h, r.tokens[0], cacheLmIds); + if (r.tokens.length() > 1) { + History hh = extendedHistory(h, r.tokens[0]); + for (u32 ti = 1;; ++ti) { + Token st = r.tokens[ti]; + sco += score_(hh, st, cacheLmIds); + if (ti + 1 >= r.tokens.length()) { + break; + } + hh = extendedHistory(hh, st); + } + } + } + sco *= ncbr->scale(); + sco += r.offset; + if (cacheScores_[r.target] > sco) { + cacheScores_[r.target] = sco; + } + } +} + +bool CombineLanguageModel::matchCacheHistory(const History& h) const { + const History* hist = reinterpret_cast(h.handle()); + if (cacheHist_.empty()) { + for (u32 i = 0; i < lms_.size(); ++i) { + if (unscaled_lms_[i]->isSparse(hist[i])) { + cacheHist_.emplace_back(); + } + else { + cacheHist_.emplace_back(hist[i]); + } + } + } + else { + for (u32 i = 0; i < lms_.size(); ++i) { + if (!unscaled_lms_[i]->isSparse(hist[i]) && !(hist[i] == cacheHist_[i])) { + return false; + } + } + } + return true; +} + } // namespace Lm diff --git a/src/Lm/CombineLm.hh b/src/Lm/CombineLm.hh index 48e100c3f..02bed9817 100644 --- a/src/Lm/CombineLm.hh +++ b/src/Lm/CombineLm.hh @@ -41,18 +41,45 @@ public: virtual Lm::Score sentenceBeginScore() const; virtual void getDependencies(Core::DependencySet& dependencies) const; - virtual History startHistory() const; - virtual History extendedHistory(History const& history, 
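
score_() keeps the skip-threshold short-circuit from the old score(): sub-LMs are evaluated in order, and once one of them returns a score at or above its threshold the remaining sub-LMs are not queried; their contribution is approximated by the last computed score unless they already have the exact value cached. A minimal sketch of that control flow for the log-linear branch (per-LM scorers passed as std::function; the cache check is left out):

    #include <functional>
    #include <vector>

    double combinedScore(std::vector<std::function<double(int)>> const& subScore,
                         std::vector<double> const& scale,
                         std::vector<double> const& skipThreshold,
                         int word) {
        double combined = 0.0;
        double prev     = 0.0;
        bool   skipRest = false;
        for (size_t i = 0; i < subScore.size(); ++i) {
            double raw = skipRest ? prev : subScore[i](word);   // reuse the last score once skipping
            if (!skipRest) {
                prev     = raw;
                skipRest = (raw >= skipThreshold[i]);
            }
            combined += scale[i] * raw;                          // linear branch would use scoreSum here
        }
        return combined;
    }
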
Token w) const; - virtual History reducedHistory(History const& history, u32 limit) const; - virtual Score score(History const& history, Token w) const; - virtual Score sentenceEndScore(const History& history) const; + virtual History startHistory() const; + virtual History extendedHistory(History const& history, Token w) const; + virtual History reducedHistory(History const& history, u32 limit) const; + virtual History reduceHistoryByN(History const&, u32 n) const; + virtual std::string formatHistory(const History&) const; + virtual Score score(const History& history, Token w) const; + virtual Score sentenceEndScore(const History& history) const; + + virtual void getBatch(const History& h, const CompiledBatchRequest* cbr, std::vector& result) const; + virtual void cacheBatch(const History& h, const CompiledBatchRequest* cbr, u32 size) const; + + virtual bool fixedHistory(s32 limit) const; + virtual bool isSparse(const History& h) const; + virtual HistorySuccessors getHistorySuccessors(const History& h) const; + virtual Score getBackOffScore(const History& h) const; + virtual Core::Ref lookaheadLanguageModel() const; virtual Core::Ref recombinationLanguageModel() const; - virtual void setSegment(Bliss::SpeechSegment const* s); + + virtual void setSegment(Bliss::SpeechSegment const* s); virtual void startFrame(Search::TimeframeIndex time) const; virtual void setInfo(History const& hist, SearchSpaceInformation const& info) const; +protected: + typedef std::unordered_map TokenScoreMap; + + template + HistorySuccessors getCombinedHistorySuccessors(const History& h) const; + + // also support partial LMs combined scoring + template + Score score_(const History& h, Token w, const std::vector& lmIds) const; + + template + void cacheBatch_(const History& h, const CompiledBatchRequest* cbr, u32 size) const; + + bool matchCacheHistory(const History& h) const; + private: std::vector> lms_; std::vector> unscaled_lms_; @@ -62,6 +89,20 @@ private: bool linear_combination_; int lookahead_lm_; int recombination_lm_; + + std::vector lmIds_; + + // cached scores for partial sparse lookahead (so far only single history cache: unigram) + mutable std::vector cacheHist_; + mutable std::vector cacheScores_; + // lexicon tokenId to requests mapping + mutable std::vector> token2Requests_; + + mutable u32 staticRequestSize_; + + std::vector staticCacheHist_; + std::vector staticCacheScores_; + std::vector> staticToken2Requests_; }; } // namespace Lm diff --git a/src/Lm/Compose.cc b/src/Lm/Compose.cc index d622bdb66..8c148d29b 100644 --- a/src/Lm/Compose.cc +++ b/src/Lm/Compose.cc @@ -35,7 +35,8 @@ class ComposeAutomaton : public Lm::ComposeAutomaton { Lm::History right; State(Fsa::StateId ll, Lm::History rr) - : left(ll), right(rr) {} + : left(ll), + right(rr) {} struct Hash { size_t operator()(const State& st) const { @@ -67,7 +68,10 @@ class ComposeAutomaton : public Lm::ComposeAutomaton { Fsa::ConstAutomatonRef left, Core::Ref right, Score lmScale, Score syntaxEmissionScale) - : left_(left), right_(right), lmScale_(lmScale), syntaxEmissionScale_(syntaxEmissionScale) { + : left_(left), + right_(right), + lmScale_(lmScale), + syntaxEmissionScale_(syntaxEmissionScale) { setProperties(Fsa::PropertyStorage | Fsa::PropertyCached, 0); setProperties(Fsa::PropertyAcyclic, Fsa::hasProperties(left_, Fsa::PropertyAcyclic)); } diff --git a/src/Lm/FixedQuantizationCompressedVectorFactory.hh b/src/Lm/FixedQuantizationCompressedVectorFactory.hh index cf36cc20e..d7939fb83 100644 --- 
a/src/Lm/FixedQuantizationCompressedVectorFactory.hh +++ b/src/Lm/FixedQuantizationCompressedVectorFactory.hh @@ -178,7 +178,9 @@ std::vector const& QuantizedFloatVectorFixedBits::data() const { } inline FixedQuantizationCompressedVectorFactory::FixedQuantizationCompressedVectorFactory(Core::Configuration const& config) - : Precursor(config), bits_per_val_(paramBitsPerVal(config)), epsilon_(paramEpsilon(config)) { + : Precursor(config), + bits_per_val_(paramBitsPerVal(config)), + epsilon_(paramEpsilon(config)) { switch (bits_per_val_) { case 8: case 16: diff --git a/src/Lm/IndexMap.hh b/src/Lm/IndexMap.hh index a9b33c322..5af27f70f 100644 --- a/src/Lm/IndexMap.hh +++ b/src/Lm/IndexMap.hh @@ -42,7 +42,8 @@ protected: void initializeMapping(InternalClassIndex nInternalClasses); InternalClassIndex newClass(); void mapToken(Token, InternalClassIndex idx); - bool isTokenMapped(Token t) const { + + bool isTokenMapped(Token t) const { return tokenMap_[t] != invalidClass; } void finalizeMapping(); diff --git a/src/Lm/LanguageModel.cc b/src/Lm/LanguageModel.cc index 38c515eda..40fcee21b 100644 --- a/src/Lm/LanguageModel.cc +++ b/src/Lm/LanguageModel.cc @@ -200,8 +200,9 @@ void LanguageModel::getBatch(const History& history, const CompiledBatchRequest* sco *= ncbr->scale(); sco += r->offset; - if (result[r->target] > sco) + if (result[r->target] > sco) { result[r->target] = sco; + } } } diff --git a/src/Lm/LanguageModel.hh b/src/Lm/LanguageModel.hh index 3ce99d236..1e652521a 100644 --- a/src/Lm/LanguageModel.hh +++ b/src/Lm/LanguageModel.hh @@ -184,7 +184,9 @@ struct Request { Score offset; Request() {} Request(const Bliss::SyntacticTokenSequence& s, u32 t, Score o = 0.0) - : tokens(s), target(t), offset(o) {} + : tokens(s), + target(t), + offset(o) {} }; typedef std::vector BatchRequest; diff --git a/src/Lm/LstmStateManager.hh b/src/Lm/LstmStateManager.hh index 1ff448538..087c6b1df 100644 --- a/src/Lm/LstmStateManager.hh +++ b/src/Lm/LstmStateManager.hh @@ -15,35 +15,127 @@ #ifndef _LM_LSTM_STATE_MANAGER_HH #define _LM_LSTM_STATE_MANAGER_HH -#include "StateManager.hh" +#include "AbstractStateManager.hh" +#include "CompressedVector.hh" namespace Lm { -class LstmStateManager : public StateManager { +template +class LstmStateManager : public AbstractStateManager { public: - using Precursor = StateManager; + using Precursor = AbstractStateManager; LstmStateManager(Core::Configuration const& config); virtual ~LstmStateManager() = default; - virtual HistoryState initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); - virtual void mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets); - virtual std::vector splitStates(StateVariables const& vars, - std::vector& suffix_lengths, - std::vector const& state_tensors, - CompressedVectorFactory const& vector_factory); -}; + virtual typename Precursor::HistoryState initialState(typename Precursor::StateVariables const& vars, CompressedVectorFactory const& vector_factory); + + virtual void mergeStates(typename Precursor::StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + typename Precursor::FeedDict& feed_dict, + typename Precursor::TargetList& targets); + + virtual std::vector splitStates( + typename Precursor::StateVariables const& vars, + std::vector& suffix_lengths, + std::vector const& state_tensors, + CompressedVectorFactory const& vector_factory); -// inline implementations 
+protected: + virtual void extendFeedDict(typename Precursor::FeedDict& feed_dict, state_variable_t const& state_var, value_t& var) = 0; + virtual void extendTargets(typename Precursor::TargetList& targets, state_variable_t const& state_var) = 0; +}; -inline LstmStateManager::LstmStateManager(Core::Configuration const& config) +template +LstmStateManager::LstmStateManager(Core::Configuration const& config) : Precursor(config) { } +template +typename LstmStateManager::Precursor::HistoryState LstmStateManager::initialState( + typename Precursor::StateVariables const& vars, + CompressedVectorFactory const& vector_factory) { + typename Precursor::HistoryState result; + result.reserve(vars.size()); + + for (auto const& var : vars) { + require_gt(var.shape.size(), 0ul); + s64 state_size = var.shape.back(); + require_ge(state_size, 0); // variable must not be of unknown size + std::vector vec(state_size, 0.0f); + auto compression_param_estimator = vector_factory.getEstimator(); + compression_param_estimator->accumulate(vec.data(), vec.size()); + auto compression_params = compression_param_estimator->estimate(); + result.emplace_back(vector_factory.compress(vec.data(), vec.size(), compression_params.get())); + } + + return result; +} + +template +void LstmStateManager::mergeStates( + typename LstmStateManager::Precursor::StateVariables const& vars, + std::vector& prefix_lengths, + std::vector::Precursor::HistoryState const*> const& prefix_states, + typename LstmStateManager::Precursor::FeedDict& feed_dict, + typename LstmStateManager::Precursor::TargetList& targets) { + require_eq(prefix_states.size(), prefix_lengths.size()); + feed_dict.reserve(vars.size()); + targets.reserve(vars.size()); + + s64 batch_size = prefix_lengths.size(); + + for (size_t v = 0ul; v < vars.size(); v++) { + s64 state_size = prefix_states.front()->at(v)->size(); + value_t var_tensor = value_t::template zeros({batch_size, state_size}); + float* data = var_tensor.template data(); + + for (size_t b = 0ul; b < static_cast(batch_size); b++) { + auto const& compressed_state = prefix_states[b]->at(v); + require_eq(compressed_state->size(), static_cast(state_size)); + compressed_state->uncompress(data + b * state_size, state_size); + } + + extendFeedDict(feed_dict, vars[v], var_tensor); + extendTargets(targets, vars[v]); + } +} + +template +std::vector::Precursor::HistoryState> LstmStateManager::splitStates( + typename LstmStateManager::Precursor::StateVariables const& vars, + std::vector& suffix_lengths, + std::vector const& state_tensors, + CompressedVectorFactory const& vector_factory) { + require_eq(vars.size(), state_tensors.size()); + + std::vector result(suffix_lengths.size()); + + for (size_t r = 0ul; r < suffix_lengths.size(); r++) { + result[r].reserve(vars.size()); + suffix_lengths[r] = 1ul; + } + + for (size_t v = 0ul; v < vars.size(); v++) { + value_t const& tensor = state_tensors[v]; + require_eq(tensor.numDims(), 2); + size_t batch_size = tensor.dimSize(0); + size_t state_size = tensor.dimSize(1); + require_eq(batch_size, suffix_lengths.size()); + + for (size_t b = 0ul; b < batch_size; b++) { + auto compression_param_estimator = vector_factory.getEstimator(); + float const* data = tensor.template data(b, 0); + compression_param_estimator->accumulate(data, state_size); + auto compression_params = compression_param_estimator->estimate(); + result[b].emplace_back(vector_factory.compress(data, state_size, compression_params.get()).release()); + } + } + + return result; +} + } // namespace Lm #endif // 
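
For the LSTM state manager, mergeStates() and splitStates() reduce to packing one state vector per hypothesis into a row-major [batch, stateSize] tensor before the forward pass and slicing the output tensor back into per-hypothesis vectors afterwards (compression and the backend tensor types are left out here). A plain-vector sketch of that packing:

    #include <algorithm>
    #include <vector>

    std::vector<float> mergeStates(std::vector<std::vector<float>> const& states) {
        size_t stateSize = states.front().size();
        std::vector<float> packed(states.size() * stateSize);
        for (size_t b = 0; b < states.size(); ++b)
            std::copy(states[b].begin(), states[b].end(), packed.begin() + b * stateSize);
        return packed;                                  // row b holds the state of hypothesis b
    }

    std::vector<std::vector<float>> splitStates(std::vector<float> const& packed,
                                                size_t batch, size_t stateSize) {
        std::vector<std::vector<float>> states(batch);
        for (size_t b = 0; b < batch; ++b)
            states[b].assign(packed.begin() + b * stateSize, packed.begin() + (b + 1) * stateSize);
        return states;
    }
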
_LM_LSTM_STATE_MANAGER_HH diff --git a/src/Lm/Makefile b/src/Lm/Makefile index 861ead682..9602d5765 100644 --- a/src/Lm/Makefile +++ b/src/Lm/Makefile @@ -15,46 +15,58 @@ LIBSPRINTLM_O = \ $(OBJDIR)/ClassLm.o \ $(OBJDIR)/CombineLm.o \ $(OBJDIR)/Compose.o \ + $(OBJDIR)/CompressedVector.o \ $(OBJDIR)/CorpusStatistics.o \ + $(OBJDIR)/FixedQuantizationCompressedVectorFactory.o \ $(OBJDIR)/IndexMap.o \ $(OBJDIR)/LanguageModel.o \ $(OBJDIR)/Module.o \ $(OBJDIR)/NNHistoryManager.o \ + $(OBJDIR)/QuantizedCompressedVectorFactory.o \ + $(OBJDIR)/RecurrentLanguageModel.o \ + $(OBJDIR)/ReducedPrecisionCompressedVectorFactory.o \ $(OBJDIR)/ReverseArpaLm.o \ $(OBJDIR)/ScaledLanguageModel.o \ $(OBJDIR)/WordlistInterface.o +#MODF AbstractStateManager.hh +#MODF DummyCompressedVectorFactory.hh ifdef MODULE_LM_ARPA LIBSPRINTLM_O += $(OBJDIR)/ArpaLm.o endif + ifdef MODULE_LM_FSA LIBSPRINTLM_O += $(OBJDIR)/FsaLm.o LIBSPRINTLM_O += $(OBJDIR)/CheatingSegmentLm.o endif + ifdef MODULE_LM_ZEROGRAM LIBSPRINTLM_O += $(OBJDIR)/Zerogram.o endif ifdef MODULE_LM_TFRNN -LIBSPRINTLM_O += $(OBJDIR)/BlasNceSoftmaxAdapter.o -LIBSPRINTLM_O += $(OBJDIR)/CompressedVector.o -LIBSPRINTLM_O += $(OBJDIR)/FixedQuantizationCompressedVectorFactory.o -LIBSPRINTLM_O += $(OBJDIR)/LstmStateManager.o -LIBSPRINTLM_O += $(OBJDIR)/NceSoftmaxAdapter.o -LIBSPRINTLM_O += $(OBJDIR)/PassthroughSoftmaxAdapter.o -LIBSPRINTLM_O += $(OBJDIR)/QuantizedBlasNceSoftmaxAdapter.o -LIBSPRINTLM_O += $(OBJDIR)/QuantizedCompressedVectorFactory.o -LIBSPRINTLM_O += $(OBJDIR)/ReducedPrecisionCompressedVectorFactory.o -LIBSPRINTLM_O += $(OBJDIR)/TransformerStateManager.o +LIBSPRINTLM_O += $(OBJDIR)/TFBlasNceSoftmaxAdapter.o +LIBSPRINTLM_O += $(OBJDIR)/TFLstmStateManager.o +LIBSPRINTLM_O += $(OBJDIR)/TFNceSoftmaxAdapter.o +LIBSPRINTLM_O += $(OBJDIR)/TFPassthroughSoftmaxAdapter.o +LIBSPRINTLM_O += $(OBJDIR)/TFQuantizedBlasNceSoftmaxAdapter.o LIBSPRINTLM_O += $(OBJDIR)/TFRecurrentLanguageModel.o -#MODF DummyCompressedVectorFactory.hh -#MODF SoftmaxAdapter.hh -#MODF StateManager.hh +LIBSPRINTLM_O += $(OBJDIR)/TFTransformerStateManager.o +#MODF TFSoftmaxAdapter.hh CXXFLAGS += $(TF_CXXFLAGS) LDFLAGS += $(TF_LDFLAGS) endif +ifdef MODULE_LM_ONNX +LIBSPRINTLM_O += $(OBJDIR)/OnnxLstmStateManager.o +LIBSPRINTLM_O += $(OBJDIR)/OnnxPassthroughSoftmaxAdapter.o +LIBSPRINTLM_O += $(OBJDIR)/OnnxNceSoftmaxAdapter.o +LIBSPRINTLM_O += $(OBJDIR)/OnnxRecurrentLanguageModel.o +#MODF OnnxStateVariable.hh +#MODF OnnxSoftmaxAdapter.hh +endif + CHECK_O = $(OBJDIR)/check.o \ ../Flf/libSprintFlf.$(a) \ ../Flf/FlfCore/libSprintFlfCore.$(a) \ @@ -63,7 +75,6 @@ CHECK_O = $(OBJDIR)/check.o \ ../Mc/libSprintMc.$(a) \ ../Bliss/libSprintBliss.$(a) \ ../Nn/libSprintNn.$(a) \ - ../Me/libSprintMe.$(a) \ ../Mm/libSprintMm.$(a) \ ../Signal/libSprintSignal.$(a) \ ../Flow/libSprintFlow.$(a) \ @@ -79,15 +90,24 @@ CHECK_O += ../Core/libSprintCore.$(a) ifdef MODULE_CART CHECK_O += ../Cart/libSprintCart.$(a) endif +ifdef MODULE_FLF_EXT +CHECK_O += ../Flf/FlfExt/libSprintFlfExt.$(a) +endif ifdef MODULE_MATH_NR CHECK_O += ../Math/Nr/libSprintMathNr.$(a) endif ifdef MODULE_PYTHON CHECK_O += ../Python/libSprintPython.$(a) endif +ifdef MODULE_FLF_EXT +CHECK_O += ../Flf/FlfExt/libSprintFlfExt.$(a) +endif ifdef MODULE_LM_TFRNN CHECK_O += ../Tensorflow/libSprintTensorflow.$(a) endif +ifdef MODULE_LM_ONNX +CHECK_O += ../Onnx/libSprintOnnx.$(a) +endif # ----------------------------------------------------------------------------- diff --git a/src/Lm/Module.cc b/src/Lm/Module.cc index 1862c9ad4..877fa99f3 100644 --- 
a/src/Lm/Module.cc +++ b/src/Lm/Module.cc @@ -13,9 +13,17 @@ * limitations under the License. */ #include "Module.hh" + #include #include #include "ClassLm.hh" +#include "CombineLm.hh" +#include "DummyCompressedVectorFactory.hh" +#include "FixedQuantizationCompressedVectorFactory.hh" +#include "QuantizedCompressedVectorFactory.hh" +#include "ReducedPrecisionCompressedVectorFactory.hh" +#include "SimpleHistoryLm.hh" + #ifdef MODULE_LM_ARPA #include "ArpaLm.hh" #endif @@ -29,17 +37,10 @@ #ifdef MODULE_LM_TFRNN #include "TFRecurrentLanguageModel.hh" #endif -#include "CombineLm.hh" - -#ifdef MODULE_LM_TFRNN -#include "DummyCompressedVectorFactory.hh" -#include "FixedQuantizationCompressedVectorFactory.hh" -#include "QuantizedCompressedVectorFactory.hh" -#include "ReducedPrecisionCompressedVectorFactory.hh" +#ifdef MODULE_LM_ONNX +#include "OnnxRecurrentLanguageModel.hh" #endif -#include "SimpleHistoryLm.hh" - using namespace Lm; namespace Lm { @@ -50,10 +51,11 @@ enum LanguageModelType { lmTypeZerogram, lmTypeCombine, lmTypeTFRNN, + lmTypeOnnx, lmTypeCheatingSegment, lmTypeSimpleHistory }; -} +} // namespace Lm const Core::Choice Module_::lmTypeChoice( "ARPA", lmTypeArpa, @@ -62,6 +64,7 @@ const Core::Choice Module_::lmTypeChoice( "zerogram", lmTypeZerogram, "combine", lmTypeCombine, "tfrnn", lmTypeTFRNN, + "onnx", lmTypeOnnx, "cheating-segment", lmTypeCheatingSegment, "simple-history", lmTypeSimpleHistory, Core::Choice::endMark()); @@ -89,6 +92,9 @@ Core::Ref Module_::createLanguageModel( case lmTypeCombine: result = Core::ref(new CombineLanguageModel(c, l)); break; #ifdef MODULE_LM_TFRNN case lmTypeTFRNN: result = Core::ref(new TFRecurrentLanguageModel(c, l)); break; +#endif +#ifdef MODULE_LM_ONNX + case lmTypeOnnx: result = Core::ref(new OnnxRecurrentLanguageModel(c, l)); break; #endif case lmTypeSimpleHistory: result = Core::ref(new SimpleHistoryLm(c, l)); break; default: @@ -105,7 +111,6 @@ Core::Ref Module_::createScaledLanguageModel( return languageModel ? 
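
Module.cc wires the new backend into the existing choice/factory scheme: "onnx" is added to lmTypeChoice and createLanguageModel() gains a case that builds an OnnxRecurrentLanguageModel, compiled in only when MODULE_LM_ONNX is defined. A stripped-down sketch of that registration pattern (stand-in types, not the Core::Choice machinery):

    #include <memory>
    #include <string>

    struct LanguageModel { virtual ~LanguageModel() = default; };
    struct ArpaLm : LanguageModel {};
    struct OnnxRecurrentLm : LanguageModel {};

    std::unique_ptr<LanguageModel> createLanguageModel(std::string const& type) {
        if (type == "ARPA")
            return std::make_unique<ArpaLm>();
    #ifdef MODULE_LM_ONNX                       // backend only linked in when the module is enabled
        if (type == "onnx")
            return std::make_unique<OnnxRecurrentLm>();
    #endif
        return nullptr;                         // unknown type: the real code reports an error instead
    }
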
Core::Ref(new LanguageModelScaling(c, languageModel)) : Core::Ref(); } -#ifdef MODULE_LM_TFRNN enum CompressedVectorFactoryType { DummyCompressedVectorFactoryType, FixedQuantizationCompressedVectorFactoryType, @@ -134,5 +139,3 @@ Lm::CompressedVectorFactoryPtr Module_::createCompressedVectorFactory(Cor default: defect(); } } - -#endif diff --git a/src/Lm/Module.hh b/src/Lm/Module.hh index 053e583f6..a9c3243fa 100644 --- a/src/Lm/Module.hh +++ b/src/Lm/Module.hh @@ -16,12 +16,10 @@ #define _LM_MODULE_HH #include -#include "LanguageModel.hh" -#include "ScaledLanguageModel.hh" -#ifdef MODULE_LM_TFRNN #include "CompressedVector.hh" -#endif +#include "LanguageModel.hh" +#include "ScaledLanguageModel.hh" namespace Lm { @@ -30,10 +28,9 @@ private: static const Core::Choice lmTypeChoice; static const Core::ParameterChoice lmTypeParam; -#ifdef MODULE_LM_TFRNN static const Core::Choice compressedVectorFactoryTypeChoice; static const Core::ParameterChoice compressedVectorFactoryTypeParam; -#endif + public: Module_() {} @@ -65,9 +62,7 @@ public: return createScaledLanguageModel(c, createLanguageModel(c, l)); } -#ifdef MODULE_LM_TFRNN Lm::CompressedVectorFactoryPtr createCompressedVectorFactory(Core::Configuration const& config); -#endif }; typedef Core::SingletonHolder Module; diff --git a/src/Lm/OnnxLstmStateManager.cc b/src/Lm/OnnxLstmStateManager.cc new file mode 100644 index 000000000..40db31091 --- /dev/null +++ b/src/Lm/OnnxLstmStateManager.cc @@ -0,0 +1,13 @@ +#include "OnnxLstmStateManager.hh" + +namespace Lm { + +void OnnxLstmStateManager::extendFeedDict(FeedDict& feed_dict, Onnx::OnnxStateVariable const& state_var, Onnx::Value& var) { + feed_dict.emplace_back(state_var.input_state_key, std::move(var)); +} + +void OnnxLstmStateManager::extendTargets(TargetList& targets, Onnx::OnnxStateVariable const& state_var) { + targets.emplace_back(state_var.output_state_key); +} + +} // namespace Lm diff --git a/src/Lm/OnnxLstmStateManager.hh b/src/Lm/OnnxLstmStateManager.hh new file mode 100644 index 000000000..a3b81a793 --- /dev/null +++ b/src/Lm/OnnxLstmStateManager.hh @@ -0,0 +1,31 @@ +#ifndef _LM_ONNX_LSTM_STATE_MANAGER_HH +#define _LM_ONNX_LSTM_STATE_MANAGER_HH + +#include +#include + +#include "LstmStateManager.hh" + +namespace Lm { + +class OnnxLstmStateManager : public LstmStateManager { +public: + using Precursor = LstmStateManager; + + OnnxLstmStateManager(Core::Configuration const& config); + virtual ~OnnxLstmStateManager() = default; + +protected: + virtual void extendFeedDict(FeedDict& feed_dict, Onnx::OnnxStateVariable const& state_var, Onnx::Value& var); + virtual void extendTargets(TargetList& targets, Onnx::OnnxStateVariable const& state_var); +}; + +// inline implementations + +inline OnnxLstmStateManager::OnnxLstmStateManager(Core::Configuration const& config) + : Precursor(config) { +} + +} // namespace Lm + +#endif // _LM_ONNX_LSTM_STATE_MANAGER_HH diff --git a/src/Lm/OnnxNceSoftmaxAdapter.cc b/src/Lm/OnnxNceSoftmaxAdapter.cc new file mode 100644 index 000000000..78530f100 --- /dev/null +++ b/src/Lm/OnnxNceSoftmaxAdapter.cc @@ -0,0 +1,64 @@ +#include "OnnxNceSoftmaxAdapter.hh" + +#include + +#include "DummyCompressedVectorFactory.hh" + +namespace Lm { + +const Core::ParameterString OnnxNceSoftmaxAdapter::paramWeightsFile("weights-file", "output embedding file", ""); +const Core::ParameterString OnnxNceSoftmaxAdapter::paramBiasFile("bias-file", "output bias file", ""); + +void OnnxNceSoftmaxAdapter::init(Onnx::Session& session, Onnx::IOMapping& mapping) { + Core::BinaryInputStream 
weight_stream(weightsFile_); + weights_.read(weight_stream); + + Core::BinaryInputStream bias_stream(biasFile_); + + u32 numRows; + bias_stream >> numRows; + + std::vector elem(numRows); + bias_stream.read(elem.data(), numRows); + + bias_.resize(numRows, 0.0f, true); + + for (size_t i = 0; i < numRows; i++) { + bias_[i] = elem[i]; + } +} + +Score OnnxNceSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { + std::vector nn_output; + float const* data; + Lm::UncompressedVector const* vec = dynamic_cast const*>(nn_out.get()); + + if (vec != nullptr) { + data = vec->data(); + } + else { + nn_output.resize(nn_out->size()); + nn_out->uncompress(nn_output.data(), nn_output.size()); + data = nn_output.data(); + } + + float result = Math::dot(nn_out->size(), data, 1, &weights_(0, output_idx), 1); + result += bias_[output_idx]; + + return result; +} + +std::vector OnnxNceSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { + std::vector nn_output(nn_out->size()); + nn_out->uncompress(nn_output.data(), nn_output.size()); + + std::vector result(output_idxs.size()); + for (size_t i = 0ul; i < output_idxs.size(); i++) { + result[i] = Math::dot(nn_output.size(), nn_output.data(), 1, &weights_(0, output_idxs[i]), 1); + result[i] += bias_[output_idxs[i]]; + } + + return result; +} + +} // namespace Lm diff --git a/src/Lm/OnnxNceSoftmaxAdapter.hh b/src/Lm/OnnxNceSoftmaxAdapter.hh new file mode 100644 index 000000000..707832e81 --- /dev/null +++ b/src/Lm/OnnxNceSoftmaxAdapter.hh @@ -0,0 +1,40 @@ +#ifndef _LM_ONNX_NCE_SOFTMAX_ADAPTER_HH +#define _LM_ONNX_NCE_SOFTMAX_ADAPTER_HH + +#include + +#include "OnnxSoftmaxAdapter.hh" + +namespace Lm { + +class OnnxNceSoftmaxAdapter : public OnnxSoftmaxAdapter { +public: + using Precursor = OnnxSoftmaxAdapter; + + static const Core::ParameterString paramWeightsFile; + static const Core::ParameterString paramBiasFile; + + OnnxNceSoftmaxAdapter(Core::Configuration const& config); + virtual ~OnnxNceSoftmaxAdapter() = default; + + virtual void init(Onnx::Session& session, Onnx::IOMapping& mapping); + virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx); + virtual std::vector get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs); + +private: + const std::string weightsFile_; + const std::string biasFile_; + + Math::FastMatrix weights_; + Math::FastVector bias_; +}; + +inline OnnxNceSoftmaxAdapter::OnnxNceSoftmaxAdapter(Core::Configuration const& config) + : Precursor(config), + weightsFile_(paramWeightsFile(config)), + biasFile_(paramBiasFile(config)) { +} + +} // namespace Lm + +#endif // _LM_ONNX_NCE_SOFTMAX_ADAPTER_HH diff --git a/src/Lm/OnnxPassthroughSoftmaxAdapter.cc b/src/Lm/OnnxPassthroughSoftmaxAdapter.cc new file mode 100644 index 000000000..19b0f03b3 --- /dev/null +++ b/src/Lm/OnnxPassthroughSoftmaxAdapter.cc @@ -0,0 +1,12 @@ +#include "OnnxPassthroughSoftmaxAdapter.hh" + +namespace Lm { + +void OnnxPassthroughSoftmaxAdapter::init(Onnx::Session& session, Onnx::IOMapping& mapping) { +} + +Score OnnxPassthroughSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { + return nn_out->get(output_idx); +} + +} // namespace Lm diff --git a/src/Lm/OnnxPassthroughSoftmaxAdapter.hh b/src/Lm/OnnxPassthroughSoftmaxAdapter.hh new file mode 100644 index 000000000..212f0e75f --- /dev/null +++ b/src/Lm/OnnxPassthroughSoftmaxAdapter.hh @@ -0,0 +1,31 @@ +#ifndef _LM_ONNX_PASSTHROUGH_SOFTMAX_ADAPTER_HH +#define 
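
OnnxNceSoftmaxAdapter::get_score() above computes the score of one vocabulary item as the dot product of the network's final hidden vector with that item's column of the output-embedding matrix, plus the item's bias: a single softmax logit evaluated lazily, without normalizing over the whole vocabulary (weights and bias are read from the configured files in init()). A plain-vector sketch of that scoring step (hypothetical container layout):

    #include <cstddef>
    #include <vector>

    float nceScore(std::vector<float> const& hidden,                          // network output for this history
                   std::vector<std::vector<float>> const& outputEmbedding,    // [vocab][hiddenDim]
                   std::vector<float> const& bias,                            // [vocab]
                   std::size_t wordIdx) {
        float s = bias[wordIdx];
        for (std::size_t d = 0; d < hidden.size(); ++d)
            s += hidden[d] * outputEmbedding[wordIdx][d];                     // dot(hidden, word's column)
        return s;
    }
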
_LM_ONNX_PASSTHROUGH_SOFTMAX_ADAPTER_HH + +#include + +#include "OnnxSoftmaxAdapter.hh" + +namespace Lm { + +class OnnxPassthroughSoftmaxAdapter : public OnnxSoftmaxAdapter { +public: + using Precursor = OnnxSoftmaxAdapter; + + OnnxPassthroughSoftmaxAdapter(Core::Configuration const& config); + virtual ~OnnxPassthroughSoftmaxAdapter() = default; + + virtual void init(Onnx::Session& session, Onnx::IOMapping& mapping); + virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx); + +private: +}; + +// inline implementations + +inline OnnxPassthroughSoftmaxAdapter::OnnxPassthroughSoftmaxAdapter(Core::Configuration const& config) + : Precursor(config) { +} + +} // namespace Lm + +#endif /* _LM_ONNX_PASSTHROUGH_SOFTMAX_ADAPTER_HH */ diff --git a/src/Lm/OnnxRecurrentLanguageModel.cc b/src/Lm/OnnxRecurrentLanguageModel.cc new file mode 100644 index 000000000..7af03faa9 --- /dev/null +++ b/src/Lm/OnnxRecurrentLanguageModel.cc @@ -0,0 +1,113 @@ +#include "OnnxRecurrentLanguageModel.hh" + +#include "OnnxLstmStateManager.hh" +#include "OnnxNceSoftmaxAdapter.hh" +#include "OnnxPassthroughSoftmaxAdapter.hh" + +namespace { + +std::vector getIOSpec(int64_t num_classes) { + return std::vector({ + Onnx::IOSpecification{"word", Onnx::IODirection::INPUT, false, {Onnx::ValueType::TENSOR}, {Onnx::ValueDataType::INT32}, {{-1, -1}}}, + Onnx::IOSpecification{"word-length", Onnx::IODirection::INPUT, false, {Onnx::ValueType::TENSOR}, {Onnx::ValueDataType::INT32}, {{-1}}}, + Onnx::IOSpecification{"nn-output", Onnx::IODirection::OUTPUT, false, {Onnx::ValueType::TENSOR}, {Onnx::ValueDataType::FLOAT}, {{-1, -1, num_classes}}}, + }); +} + +} // namespace + +namespace Lm { + +enum OnnxStateManagerType { + OnnxLstmStateManagerType, + OnnxTransformerStateManagerType, + OnnxTransformerStateManager16BitType, + OnnxTransformerStateManager8BitType, + OnnxTransformerStateManagerWithCommonPrefixType, + OnnxTransformerStateManagerWithCommonPrefix16BitType, + OnnxTransformerStateManagerWithCommonPrefix8BitType, +}; + +const Core::Choice stateManagerTypeChoice( + "lstm", OnnxLstmStateManagerType, + Core::Choice::endMark()); + +const Core::ParameterChoice stateManagerTypeParam( + "type", &stateManagerTypeChoice, + "type of the state manager", + OnnxLstmStateManagerType); + +std::unique_ptr createOnnxStateManager(Core::Configuration const& config) { + OnnxStateManager* res = nullptr; + switch (stateManagerTypeParam(config)) { + case OnnxLstmStateManagerType: res = new Lm::OnnxLstmStateManager(config); break; + default: defect(); + } + return std::unique_ptr(res); +} + +enum OnnxSoftmaxAdapterType { + OnnxPassthroughSoftmaxAdapterType, + OnnxNceSoftmaxAdapterType, +}; + +const Core::Choice softmaxAdapterTypeChoice( + "passthrough", OnnxPassthroughSoftmaxAdapterType, + "nce", OnnxNceSoftmaxAdapterType, + Core::Choice::endMark()); + +const Core::ParameterChoice softmaxAdapterTypeParam( + "type", &softmaxAdapterTypeChoice, + "type of the softmax adapter", + OnnxPassthroughSoftmaxAdapterType); + +std::unique_ptr createOnnxSoftmaxAdapter(Core::Configuration const& config) { + switch (softmaxAdapterTypeParam(config)) { + case OnnxPassthroughSoftmaxAdapterType: return std::unique_ptr(new Lm::OnnxPassthroughSoftmaxAdapter(config)); + case OnnxNceSoftmaxAdapterType: return std::unique_ptr(new Lm::OnnxNceSoftmaxAdapter(config)); + default: defect(); + } +} + +OnnxRecurrentLanguageModel::OnnxRecurrentLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l) + : Core::Component(c), + Precursor(c, l, 
createOnnxStateManager(select("state-manager"))), + session_(select("session")), + ioSpec_(getIOSpec(-2)), + mapping_(select("io-map"), ioSpec_), + validator_(select("validator")), + softmax_adapter_(createOnnxSoftmaxAdapter(select("softmax-adapter"))) { + state_variables_ = session_.getStateVariablesMetadata(); + + setEmptyHistory(); + + softmax_adapter_->init(session_, mapping_); + validator_.validate(ioSpec_, mapping_, session_); +} + +void OnnxRecurrentLanguageModel::setState(std::vector> const& inputs, std::vector const& targets) const { +} + +void OnnxRecurrentLanguageModel::extendInputs(std::vector>& inputs, Math::FastMatrix const& words, Math::FastVector const& word_lengths, std::vector const& state_lengths) const { + inputs.emplace_back(mapping_.getOnnxName("word"), Onnx::Value::create(words)); + inputs.emplace_back(mapping_.getOnnxName("word-length"), Onnx::Value::create(word_lengths)); +} + +void OnnxRecurrentLanguageModel::extendTargets(std::vector& targets) const { + targets.emplace(targets.begin(), mapping_.getOnnxName("nn-output")); +} + +void OnnxRecurrentLanguageModel::getOutputs(std::vector>& inputs, std::vector& outputs, std::vector const& targets) const { + session_.run(std::move(inputs), targets, outputs); +} + +std::vector OnnxRecurrentLanguageModel::fetchStates(std::vector& outputs) const { + std::vector state_vars(std::make_move_iterator(outputs.begin() + 1), std::make_move_iterator(outputs.end())); + return state_vars; +} + +Score OnnxRecurrentLanguageModel::transformOutput(Lm::CompressedVectorPtr const& nn_output, size_t index) const { + return softmax_adapter_->get_score(nn_output, index); +} + +} // namespace Lm diff --git a/src/Lm/OnnxRecurrentLanguageModel.hh b/src/Lm/OnnxRecurrentLanguageModel.hh new file mode 100644 index 000000000..7485a493a --- /dev/null +++ b/src/Lm/OnnxRecurrentLanguageModel.hh @@ -0,0 +1,40 @@ +#ifndef _LM_ONNX_RECURRENT_LANGUAGE_MODEL_HXX +#define _LM_ONNX_RECURRENT_LANGUAGE_MODEL_HXX + +#include +#include +#include + +#include "OnnxSoftmaxAdapter.hh" +#include "RecurrentLanguageModel.hh" + +namespace Lm { + +class OnnxRecurrentLanguageModel : public RecurrentLanguageModel { +public: + using Precursor = RecurrentLanguageModel; + + OnnxRecurrentLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l); + virtual ~OnnxRecurrentLanguageModel() {} + +protected: + virtual void setState(std::vector> const& inputs, std::vector const& targets) const; + virtual void extendInputs(std::vector>& inputs, Math::FastMatrix const& words, Math::FastVector const& word_lengths, std::vector const& state_lengths) const; + virtual void extendTargets(std::vector& targets) const; + virtual void getOutputs(std::vector>& inputs, std::vector& outputs, std::vector const& targets) const; + virtual std::vector fetchStates(std::vector& outputs) const; + + virtual Score transformOutput(Lm::CompressedVectorPtr const& nn_output, size_t index) const; + +private: + mutable Onnx::Session session_; + std::vector ioSpec_; + Onnx::IOMapping mapping_; + Onnx::IOValidator validator_; + + std::unique_ptr softmax_adapter_; +}; + +} // namespace Lm + +#endif // _LM_ONNX_RECURRENT_LANGUAGE_MODEL_HXX diff --git a/src/Lm/OnnxSoftmaxAdapter.hh b/src/Lm/OnnxSoftmaxAdapter.hh new file mode 100644 index 000000000..ed9218b10 --- /dev/null +++ b/src/Lm/OnnxSoftmaxAdapter.hh @@ -0,0 +1,45 @@ +#ifndef _LM_ONNX_SOFTMAX_ADAPTER_HH +#define _LM_ONNX_SOFTMAX_ADAPTER_HH + +#include +#include +#include + +#include "CompressedVector.hh" + +namespace Lm { + +using Score = float; + +class 
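
OnnxRecurrentLanguageModel plugs the generic RecurrentLanguageModel flow into an ONNX session: the state manager's merged state tensors plus the "word"/"word-length" tensors form the feed, "nn-output" is fetched first followed by the output state variables, and fetchStates() accordingly takes outputs[1..] as the new states. A compact sketch of that single forward step with a dummy Session stand-in (not the Onnx::Session API):

    #include <string>
    #include <utility>
    #include <vector>

    struct Tensor {};
    struct Session {                                     // dummy stand-in for the ONNX runtime session
        void run(std::vector<std::pair<std::string, Tensor>>&&,
                 std::vector<std::string> const& targets, std::vector<Tensor>& outputs) {
            outputs.assign(targets.size(), Tensor{});    // real code fills these from the graph
        }
    };

    void forwardStep(Session& session,
                     std::vector<std::pair<std::string, Tensor>> feed,   // merged states from the state manager
                     std::vector<std::string> targets,                   // output state names from the state manager
                     Tensor words, Tensor wordLengths,
                     Tensor& nnOutput, std::vector<Tensor>& newStates) {
        feed.emplace_back("word", std::move(words));
        feed.emplace_back("word-length", std::move(wordLengths));
        targets.insert(targets.begin(), "nn-output");    // scores first, new states after

        std::vector<Tensor> outputs;
        session.run(std::move(feed), targets, outputs);

        nnOutput = outputs.front();
        newStates.assign(outputs.begin() + 1, outputs.end());
    }
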
OnnxSoftmaxAdapter : public Core::Component { +public: + using Precursor = Core::Component; + + OnnxSoftmaxAdapter(Core::Configuration const& config); + virtual ~OnnxSoftmaxAdapter() = default; + + virtual void init(Onnx::Session& session, Onnx::IOMapping& mapping) = 0; + virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) = 0; + virtual std::vector get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs); + +private: +}; + +// inline implementations + +inline OnnxSoftmaxAdapter::OnnxSoftmaxAdapter(Core::Configuration const& config) + : Precursor(config) { +} + +inline std::vector OnnxSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { + std::vector scores; + scores.reserve(output_idxs.size()); + for (size_t output_idx : output_idxs) { + scores.push_back(get_score(nn_out, output_idx)); + } + return scores; +} + +} // namespace Lm + +#endif /* _LM_ONNX_SOFTMAX_ADAPTER_HH */ diff --git a/src/Lm/OnnxStateVariable.hh b/src/Lm/OnnxStateVariable.hh new file mode 100644 index 000000000..a74f03294 --- /dev/null +++ b/src/Lm/OnnxStateVariable.hh @@ -0,0 +1,18 @@ +#ifndef _LM_ONNX_STATE_VARIABLE_HH +#define _LM_ONNX_STATE_VARIABLE_HH + +#include + +#include + +namespace Lm { + +struct OnnxStateVariable { + std::string input_state_key; + std::string output_state_key; + std::vector shape; +}; + +} // namespace Lm + +#endif // _LM_ONNX_STATE_VARIABLE_HH \ No newline at end of file diff --git a/src/Lm/QuantizedCompressedVectorFactory.hh b/src/Lm/QuantizedCompressedVectorFactory.hh index d05045edf..d0bec1d61 100644 --- a/src/Lm/QuantizedCompressedVectorFactory.hh +++ b/src/Lm/QuantizedCompressedVectorFactory.hh @@ -59,7 +59,8 @@ private: struct QuantizedCompressionParameters : public CompressionParameters { QuantizedCompressionParameters(float min_val, float max_val) - : min_val(min_val), max_val(max_val) {} + : min_val(min_val), + max_val(max_val) {} virtual ~QuantizedCompressionParameters() = default; float min_val; @@ -89,7 +90,8 @@ public: static const Core::ParameterInt paramBitsPerVal; QuantizedCompressedVectorFactory(Core::Configuration const& config) - : Precursor(config), bits_per_val_(paramBitsPerVal(config)) {} + : Precursor(config), + bits_per_val_(paramBitsPerVal(config)) {} virtual ~QuantizedCompressedVectorFactory() = default; virtual CompressionParameterEstimatorPtr getEstimator() const; diff --git a/src/Lm/RecurrentLanguageModel.cc b/src/Lm/RecurrentLanguageModel.cc new file mode 100644 index 000000000..77a6ca0a4 --- /dev/null +++ b/src/Lm/RecurrentLanguageModel.cc @@ -0,0 +1,146 @@ +#include "RecurrentLanguageModel.hh" + +namespace Lm::detail { + +void RequestGraph::add_cache(ScoresWithContext* cache) { + std::vector request_chain; + request_chain.push_back(cache); + ScoresWithContext* parent = const_cast(reinterpret_cast(cache->parent.handle())); + request_chain.push_back(parent); + while (parent->state.empty()) { + parent = const_cast(reinterpret_cast(parent->parent.handle())); + request_chain.push_back(parent); + } + + std::vector* child_idxs = &roots; + while (not request_chain.empty()) { + // find root node + size_t child_idx = child_idxs->size(); + for (size_t c = 0ul; c < child_idxs->size(); c++) { + if (entries[child_idxs->at(c)] == request_chain.back()) { + child_idx = c; + break; + } + } + size_t next_child_idx = 0ul; + if (child_idx == child_idxs->size()) { + child_idxs->push_back(entries.size()); + entries.push_back(request_chain.back()); + next_child_idx = 
child_idxs->at(child_idx); + children.emplace_back(); // can invalidate child_idxs + } + else { + next_child_idx = child_idxs->at(child_idx); + } + child_idxs = &children[next_child_idx]; + request_chain.pop_back(); + } +} + +void RequestGraph::get_requests_dfs(std::vector& requests, ScoresWithContext* initial, size_t entry, size_t length) const { + if (children[entry].empty()) { + requests.emplace_back(FwdRequest{initial, entries[entry], length}); + } + else { + for (size_t e : children[entry]) { + get_requests_dfs(requests, initial, e, length + 1ul); + } + } +} + +std::vector RequestGraph::get_requests() const { + std::vector result; + for (size_t r : roots) { + for (size_t c : children[r]) { + get_requests_dfs(result, entries[r], c, 1ul); + } + } + return result; +} + +void dump_scores(ScoresWithContext const& cache, std::string const& prefix) { + std::stringstream path; + path << prefix; + for (auto token : *cache.history) { + path << "_" << token; + } + std::ofstream out(path.str(), std::ios::out | std::ios::trunc); + out << "nn_output:\n"; + std::vector nn_output(cache.nn_output->size()); + cache.nn_output->uncompress(nn_output.data(), nn_output.size()); + for (auto nn_out : nn_output) { + out << nn_out << '\n'; + } + for (size_t s = 0ul; s < cache.state.size(); s++) { + out << "state " << s << ":\n"; + std::vector state_data(cache.state[s]->size()); + cache.state[s]->uncompress(state_data.data(), state_data.size()); + for (auto v : state_data) { + out << v << '\n'; + } + } +} + +TimeStatistics TimeStatistics::operator+(TimeStatistics const& other) const { + TimeStatistics res; + + res.total_duration = total_duration + other.total_duration; + res.early_request_duration = early_request_duration + other.early_request_duration; + res.request_duration = request_duration + other.request_duration; + res.prepare_duration = prepare_duration + other.prepare_duration; + res.merge_state_duration = merge_state_duration + other.merge_state_duration; + res.set_state_duration = set_state_duration + other.set_state_duration; + res.run_nn_output_duration = run_nn_output_duration + other.run_nn_output_duration; + res.set_nn_output_duration = set_nn_output_duration + other.set_nn_output_duration; + res.get_new_state_duration = get_new_state_duration + other.get_new_state_duration; + res.split_state_duration = split_state_duration + other.split_state_duration; + res.softmax_output_duration = softmax_output_duration + other.softmax_output_duration; + + return res; +} + +TimeStatistics& TimeStatistics::operator+=(TimeStatistics const& other) { + total_duration += other.total_duration; + early_request_duration += other.early_request_duration; + request_duration += other.request_duration; + prepare_duration += other.prepare_duration; + merge_state_duration += other.merge_state_duration; + set_state_duration += other.set_state_duration; + run_nn_output_duration += other.run_nn_output_duration; + set_nn_output_duration += other.set_nn_output_duration; + get_new_state_duration += other.get_new_state_duration; + split_state_duration += other.split_state_duration; + softmax_output_duration += other.softmax_output_duration; + + return *this; +} + +void TimeStatistics::write(Core::XmlChannel& channel) const { + channel << Core::XmlOpen("total-duration") + Core::XmlAttribute("unit", "milliseconds") << total_duration.count() << Core::XmlClose("total-duration"); + channel << Core::XmlOpen("early-request-duration") + Core::XmlAttribute("unit", "milliseconds") << early_request_duration.count() << 
Core::XmlClose("early-request-duration"); + channel << Core::XmlOpen("request-duration") + Core::XmlAttribute("unit", "milliseconds") << request_duration.count() << Core::XmlClose("request-duration"); + channel << Core::XmlOpen("prepare-duration") + Core::XmlAttribute("unit", "milliseconds") << prepare_duration.count() << Core::XmlClose("prepare-duration"); + channel << Core::XmlOpen("merge-state-duration") + Core::XmlAttribute("unit", "milliseconds") << merge_state_duration.count() << Core::XmlClose("merge-state-duration"); + channel << Core::XmlOpen("set-state-duration") + Core::XmlAttribute("unit", "milliseconds") << set_state_duration.count() << Core::XmlClose("set-state-duration"); + channel << Core::XmlOpen("run-nn-output-duration") + Core::XmlAttribute("unit", "milliseconds") << run_nn_output_duration.count() << Core::XmlClose("run-nn-output-duration"); + channel << Core::XmlOpen("set-nn-output-duration") + Core::XmlAttribute("unit", "milliseconds") << set_nn_output_duration.count() << Core::XmlClose("set-nn-output-duration"); + channel << Core::XmlOpen("get-new-state-duration") + Core::XmlAttribute("unit", "milliseconds") << get_new_state_duration.count() << Core::XmlClose("get-new-state-duration"); + channel << Core::XmlOpen("split-state-duration") + Core::XmlAttribute("unit", "milliseconds") << split_state_duration.count() << Core::XmlClose("split-state-duration"); + channel << Core::XmlOpen("softmax-output-duration") + Core::XmlAttribute("unit", "milliseconds") << softmax_output_duration.count() << Core::XmlClose("softmax-output-duration"); +} + +void TimeStatistics::write(std::ostream& out) const { + out << "fwd: " << total_duration.count() + << " er:" << early_request_duration.count() + << " r:" << request_duration.count() + << " p:" << prepare_duration.count() + << " ms: " << merge_state_duration.count() + << " sst:" << set_state_duration.count() + << " rs:" << run_nn_output_duration.count() + << " sno:" << set_nn_output_duration.count() + << " gns:" << get_new_state_duration.count() + << " ss: " << split_state_duration.count() + << " smo:" << softmax_output_duration.count(); +} + +} // namespace Lm::detail diff --git a/src/Lm/RecurrentLanguageModel.hh b/src/Lm/RecurrentLanguageModel.hh new file mode 100644 index 000000000..a1780e557 --- /dev/null +++ b/src/Lm/RecurrentLanguageModel.hh @@ -0,0 +1,887 @@ +#ifndef _LM_RECURRENT_LANGUAGE_MODEL_HH +#define _LM_RECURRENT_LANGUAGE_MODEL_HH + +#include +#include +#include +#include + +#include +#include + +#include "AbstractNNLanguageModel.hh" +#include "AbstractStateManager.hh" +#include "Module.hh" +#include "SearchSpaceAwareLanguageModel.hh" + +namespace Lm { +template +class RecurrentLanguageModel; + +namespace detail { + +struct ScoresWithContext : public Lm::NNCacheWithStats { + virtual ~ScoresWithContext() = default; + + std::atomic computed; + Lm::History parent; + Lm::CompressedVectorPtr nn_output; + std::vector> state; + Lm::SearchSpaceInformation info; + Search::TimeframeIndex last_used; + Search::TimeframeIndex last_info; + bool was_expanded; +}; + +struct FwdRequest { + ScoresWithContext* initial_cache; + ScoresWithContext* final_cache; + size_t length; + + bool operator==(FwdRequest const& other) const { + return final_cache == other.final_cache; + } +}; + +struct RequestGraph { + std::vector entries; + std::vector> children; + std::vector roots; + + void add_cache(ScoresWithContext* cache); + void get_requests_dfs(std::vector& requests, ScoresWithContext* initial, size_t entry, size_t length) const; + std::vector 
get_requests() const; +}; + +void dump_scores(ScoresWithContext const& cache, std::string const& prefix); + +template +void clear_queue(typename Lm::RecurrentLanguageModel::HistoryQueue& queue) { + Lm::History const* hist = nullptr; + while (queue.try_dequeue(hist)) { + delete hist; + } +} + +struct TimeStatistics { + std::chrono::duration total_duration; + std::chrono::duration early_request_duration; + std::chrono::duration request_duration; + std::chrono::duration prepare_duration; + std::chrono::duration merge_state_duration; + std::chrono::duration set_state_duration; + std::chrono::duration run_nn_output_duration; + std::chrono::duration set_nn_output_duration; + std::chrono::duration get_new_state_duration; + std::chrono::duration split_state_duration; + std::chrono::duration softmax_output_duration; + + TimeStatistics operator+(TimeStatistics const& other) const; + TimeStatistics& operator+=(TimeStatistics const& other); + void write(Core::XmlChannel& channel) const; + void write(std::ostream& out) const; +}; + +} // namespace detail + +template +class RecurrentLanguageModel : public AbstractNNLanguageModel, public SearchSpaceAwareLanguageModel { +public: + typedef AbstractNNLanguageModel Precursor; + typedef moodycamel::BlockingReaderWriterQueue HistoryQueue; + + static const Core::ParameterBool paramTransformOuputLog; + static const Core::ParameterBool paramTransformOuputNegate; + static const Core::ParameterInt paramMinBatchSize; + static const Core::ParameterInt paramOptBatchSize; + static const Core::ParameterInt paramMaxBatchSize; + static const Core::ParameterInt paramHistoryPruningThreshold; + static const Core::ParameterInt paramPrunedHistoryLength; + static const Core::ParameterFloat paramBatchPruningThreshold; + static const Core::ParameterBool paramAllowReducedHistory; + static const Core::ParameterBool paramDumpInputs; + static const Core::ParameterString paramDumpInputsPrefix; + static const Core::ParameterBool paramDumpScores; + static const Core::ParameterString paramDumpScoresPrefix; + static const Core::ParameterBool paramLogMemory; + static const Core::ParameterBool paramFreeMemory; + static const Core::ParameterInt paramFreeMemoryDelay; + static const Core::ParameterBool paramAsync; + static const Core::ParameterBool paramSingleStepOnly; + static const Core::ParameterBool paramVerbose; + + RecurrentLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l, std::unique_ptr> state_manager); + virtual ~RecurrentLanguageModel(); + + virtual History startHistory() const; + virtual History extendedHistory(History const& hist, Token w) const; + virtual History extendedHistory(History const& hist, Bliss::Token::Id w) const; + virtual History reducedHistory(History const& hist, u32 limit) const; + virtual History reduceHistoryByN(History const&, u32 n) const; + virtual Score score(History const& hist, Token w) const; + virtual bool scoreCached(History const& hist, Token w) const; + + virtual void startFrame(Search::TimeframeIndex time) const; + virtual void setInfo(History const& hist, SearchSpaceInformation const& info) const; + +protected: + virtual void load(); + + virtual void setState(std::vector> const& inputs, std::vector const& targets) const = 0; + virtual void extendInputs(std::vector>& inputs, Math::FastMatrix const& words, Math::FastVector const& word_lengths, std::vector const& state_lengths) const = 0; + virtual void extendTargets(std::vector& targets) const = 0; + virtual void getOutputs(std::vector>& inputs, std::vector& outputs, std::vector const& 
targets) const = 0; + virtual std::vector fetchStates(std::vector& outputs) const = 0; + virtual Score transformOutput(Lm::CompressedVectorPtr const& nn_output, size_t index) const = 0; + + std::vector state_variables_; + + void setEmptyHistory(); + +private: + using ScoresWithContext = detail::ScoresWithContext; + + bool transform_output_log_; + bool transform_output_negate_; + size_t min_batch_size_; + size_t opt_batch_size_; + size_t max_batch_size_; + size_t history_pruning_threshold_; + size_t pruned_history_length_; + Score batch_pruning_threshold_; + bool allow_reduced_history_; + bool dump_inputs_; + std::string dump_inputs_prefix_; + bool dump_scores_; + std::string dump_scores_prefix_; + bool log_memory_; + bool free_memory_; + Search::TimeframeIndex free_memory_delay_; + bool single_step_only_; + bool verbose_; + + mutable Core::XmlChannel statistics_; + mutable Search::TimeframeIndex current_time_; + mutable std::vector run_time_; + mutable std::vector run_count_; + mutable double total_wait_time_; + mutable double total_start_frame_time_; + mutable double total_expand_hist_time_; + mutable detail::TimeStatistics fwd_statistics_; + mutable size_t dump_inputs_counter_; + + std::unique_ptr> state_manager_; + + std::function output_transform_function_; + CompressedVectorFactoryPtr state_comp_vec_factory_; + CompressedVectorFactoryPtr nn_output_comp_vec_factory_; + + History empty_history_; // a history used to provide the previous (all zero) state to the first real history (1 sentence-begin token) + + // members for async forwarding + bool should_stop_; + std::thread background_forwarder_thread_; + bool async_; + + void background_forward() const; + + mutable std::atomic to_fwd_; + mutable std::promise to_fwd_finished_; + + mutable std::vector pending_; + mutable HistoryQueue fwd_queue_; + mutable HistoryQueue finished_queue_; + + History extendHistoryWithOutputIdx(History const& hist, size_t w) const; + + template + void forward(Lm::History const* hist) const; +}; + +template +const Core::ParameterBool RecurrentLanguageModel::paramTransformOuputLog("transform-output-log", "apply log to tensorflow output", false); + +template +const Core::ParameterBool RecurrentLanguageModel::paramTransformOuputNegate("transform-output-negate", "negate tensorflow output (after log)", false); + +template +const Core::ParameterInt RecurrentLanguageModel::paramMinBatchSize("min-batch-size", "minimum number of histories forwarded in one go", 32); + +template +const Core::ParameterInt RecurrentLanguageModel::paramOptBatchSize("opt-batch-size", "optimum number of histories forwarded in one go", 128); + +template +const Core::ParameterInt RecurrentLanguageModel::paramMaxBatchSize("max-batch-size", "maximum number of histories forwarded in one go", 2048); + +template +const Core::ParameterInt RecurrentLanguageModel::paramHistoryPruningThreshold("history-pruning-threshold", "if the history is longer than this parameter it will be pruned", std::numeric_limits::max(), 0); + +template +const Core::ParameterInt RecurrentLanguageModel::paramPrunedHistoryLength("pruned-history-length", "length of the pruned history (should be smaller than history-pruning-threshold)", std::numeric_limits::max(), 0); + +template +const Core::ParameterFloat RecurrentLanguageModel::paramBatchPruningThreshold("batch-pruning-threshold", "pruning threshold for all hypothesis beyond min-batch-size during eager forwarding", 10.0); + +template +const Core::ParameterBool 
RecurrentLanguageModel::paramAllowReducedHistory("allow-reduced-history", "wether this LM will actually reduce the history length", false); + +template +const Core::ParameterBool RecurrentLanguageModel::paramDumpInputs("dump-inputs", "write all inputs from this LM to disk", false); + +template +const Core::ParameterString RecurrentLanguageModel::paramDumpInputsPrefix("dump-inputs-prefix", "prefix for the input dumps", "inputs"); + +template +const Core::ParameterBool RecurrentLanguageModel::paramDumpScores("dump-scores", "write all scores from this LM to disk", false); + +template +const Core::ParameterString RecurrentLanguageModel::paramDumpScoresPrefix("dump-scores-prefix", "prefix for the score dumps", "scores"); + +template +const Core::ParameterBool RecurrentLanguageModel::paramLogMemory("log-memory", "wether memory usage from nn-outputs / states should be logged", false); + +template +const Core::ParameterBool RecurrentLanguageModel::paramFreeMemory("free-memory", "wether nn-outputs should be deleted after some delay", false); + +template +const Core::ParameterInt RecurrentLanguageModel::paramFreeMemoryDelay("free-memory-delay", "how many time frames without usage before nn-outputs are deleted", 40); + +template +const Core::ParameterBool RecurrentLanguageModel::paramAsync("async", "wether to forward histories in a separate thread", false); + +template +const Core::ParameterBool RecurrentLanguageModel::paramSingleStepOnly("single-step-only", "workaround for some bug that results in wrong scores when recombination is done in combination with async evaluation", false); + +template +const Core::ParameterBool RecurrentLanguageModel::paramVerbose("verbose", "wether to print detailed statistics to stderr", false); + +template +RecurrentLanguageModel::RecurrentLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l, std::unique_ptr> state_manager) + : Core::Component(c), + Precursor(c, l), + transform_output_log_(paramTransformOuputLog(config)), + transform_output_negate_(paramTransformOuputNegate(config)), + min_batch_size_(paramMinBatchSize(config)), + opt_batch_size_(paramOptBatchSize(config)), + max_batch_size_(paramMaxBatchSize(config)), + history_pruning_threshold_(paramHistoryPruningThreshold(config)), + pruned_history_length_(paramPrunedHistoryLength(config)), + batch_pruning_threshold_(paramBatchPruningThreshold(config)), + allow_reduced_history_(paramAllowReducedHistory(config)), + dump_inputs_(paramDumpInputs(config)), + dump_inputs_prefix_(paramDumpInputsPrefix(config)), + dump_scores_(paramDumpScores(config)), + dump_scores_prefix_(paramDumpScoresPrefix(config)), + log_memory_(paramLogMemory(config)), + free_memory_(paramFreeMemory(config)), + free_memory_delay_(paramFreeMemoryDelay(config)), + single_step_only_(paramSingleStepOnly(config)), + verbose_(paramVerbose(config)), + statistics_(config, "statistics"), + current_time_(0u), + run_time_(max_batch_size_, 0.0), + run_count_(max_batch_size_, 0ul), + total_wait_time_(0.0), + total_start_frame_time_(0.0), + total_expand_hist_time_(0.0), + fwd_statistics_(), + dump_inputs_counter_(0ul), + state_manager_(std::move(state_manager)), + output_transform_function_(), + state_comp_vec_factory_(Lm::Module::instance().createCompressedVectorFactory(select("state-compression"))), + nn_output_comp_vec_factory_(Lm::Module::instance().createCompressedVectorFactory(select("nn-output-compression"))), + empty_history_(), + should_stop_(false), + background_forwarder_thread_(), + async_(paramAsync(config)), + to_fwd_(nullptr), + 
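+          /* the remaining initializers set up the asynchronous forwarding machinery (used when "async" is enabled):
+             the promise that signals a finished score request, the list of pending histories, and the two
+             32768-entry queues that shuttle histories between the search thread and the background forwarder */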
to_fwd_finished_(), + pending_(), + fwd_queue_(32768), + finished_queue_(32768) { + if (transform_output_log_ and transform_output_negate_) { + output_transform_function_ = [](Score v) { + return -std::log(v); + }; + } + else if (transform_output_log_) { + output_transform_function_ = [](Score v) { + return std::log(v); + }; + } + else if (transform_output_negate_) { + output_transform_function_ = [](Score v) { + return -v; + }; + } + + if (async_) { + background_forwarder_thread_ = std::thread(std::bind(&RecurrentLanguageModel::background_forward, this)); + } + + require_le(pruned_history_length_, history_pruning_threshold_); +} + +template +RecurrentLanguageModel::~RecurrentLanguageModel() { + detail::clear_queue(finished_queue_); + + if (async_) { + should_stop_ = true; + background_forwarder_thread_.join(); + } + + size_t total_run_count = 0ul; + size_t total_fwd_hist = 0ul; + double total_run_time = 0.0; + + statistics_ << Core::XmlOpen("fwd-time"); + for (size_t i = 0ul; i < run_count_.size(); i++) { + if (run_count_[i] > 0ul) { + statistics_ << (i + 1) << " " << run_count_[i] << " " << run_time_[i] << "\n"; + total_run_count += run_count_[i]; + total_fwd_hist += (i + 1) * run_count_[i]; + total_run_time += run_time_[i]; + } + } + statistics_ << Core::XmlClose("fwd-time"); + + statistics_ << Core::XmlOpen("fwd-summary"); + statistics_ << Core::XmlOpen("total-run-count") << total_run_count << Core::XmlClose("total-run-count"); + statistics_ << Core::XmlOpen("total-fwd-hist") << total_fwd_hist << Core::XmlClose("total-fwd-hist"); + statistics_ << Core::XmlOpen("total-run-time") + Core::XmlAttribute("unit", "milliseconds") << total_run_time << Core::XmlClose("total-run-time"); + statistics_ << Core::XmlOpen("total-wait-time") + Core::XmlAttribute("unit", "milliseconds") << total_wait_time_ << Core::XmlClose("total-wait-time"); + statistics_ << Core::XmlOpen("total-start-frame-time") + Core::XmlAttribute("unit", "milliseconds") << total_start_frame_time_ << Core::XmlClose("total-start-frame-time"); + statistics_ << Core::XmlOpen("total-expand-hist-time") + Core::XmlAttribute("unit", "milliseconds") << total_expand_hist_time_ << Core::XmlClose("total-expand-hist-time"); + statistics_ << Core::XmlOpen("fwd-times"); + fwd_statistics_.write(statistics_); + statistics_ << Core::XmlClose("fwd-times"); + statistics_ << Core::XmlClose("fwd-summary"); +} + +template +History RecurrentLanguageModel::startHistory() const { + NNHistoryManager* hm = dynamic_cast(historyManager_); + TokenIdSequence ts(1ul, lexicon_mapping_[sentenceBeginToken()->id()]); + HistoryHandle h = hm->get(ts); + ScoresWithContext* cache = const_cast(reinterpret_cast(h)); + cache->parent = empty_history_; + History hist(history(h)); + return hist; +} + +template +void RecurrentLanguageModel::setEmptyHistory() { + NNHistoryManager* hm = dynamic_cast(historyManager_); + TokenIdSequence ts; + HistoryHandle h = hm->get(ts); + ScoresWithContext* cache = const_cast(reinterpret_cast(h)); + cache->state = state_manager_->initialState(state_variables_, *state_comp_vec_factory_); + + if (cache->state.empty()) { + error("LM has no state variables. 
Did you forget to compile with 'initial_state': 'keep_over_epoch_no_init' for TensorFlow or 'initial_state': 'placeholder' for Onnx?"); + } + + std::vector temp(1); + auto compression_param_estimator = nn_output_comp_vec_factory_->getEstimator(); + compression_param_estimator->accumulate(temp.data(), temp.size()); + auto compression_params = compression_param_estimator->estimate(); + // pretend this history has already been evaluated + cache->nn_output = nn_output_comp_vec_factory_->compress(temp.data(), temp.size(), compression_params.get()); + cache->computed.store(true); + cache->last_used = std::numeric_limits::max(); + empty_history_ = history(h); +} + +template +History RecurrentLanguageModel::extendedHistory(History const& hist, Token w) const { + return extendedHistory(hist, w->id()); +} + +template +History RecurrentLanguageModel::extendedHistory(History const& hist, Bliss::Token::Id w) const { + return extendHistoryWithOutputIdx(hist, lexicon_mapping_[w]); +} + +template +History RecurrentLanguageModel::reducedHistory(History const& hist, u32 limit) const { + ScoresWithContext const* sc = reinterpret_cast(hist.handle()); + if (not allow_reduced_history_ or sc->history->size() <= limit) { + return hist; + } + History h = startHistory(); + for (u32 w = limit; w > 0; w--) { + h = extendHistoryWithOutputIdx(h, sc->history->at(sc->history->size() - w)); + } + return h; +} + +template +History RecurrentLanguageModel::reduceHistoryByN(History const& hist, u32 n) const { + if (not allow_reduced_history_) { + return hist; + } + ScoresWithContext const* sc = reinterpret_cast(hist.handle()); + History h = startHistory(); + for (u32 w = n; w < sc->history->size(); w++) { + h = extendHistoryWithOutputIdx(h, sc->history->at(w)); + } + return h; +} + +template +Score RecurrentLanguageModel::score(History const& hist, Token w) const { + ScoresWithContext* sc = const_cast(reinterpret_cast(hist.handle())); + + if (not sc->computed.load()) { + auto start = std::chrono::steady_clock::now(); + if (async_) { + // promise should only be used once + to_fwd_finished_ = std::promise(); + std::future future = to_fwd_finished_.get_future(); + to_fwd_.store(&hist); + future.wait(); + } + else { + forward(&hist); + } + auto end = std::chrono::steady_clock::now(); + double wait_time = std::chrono::duration(end - start).count(); + total_wait_time_ += wait_time; + if (verbose_) { + std::cerr << "wait: " << wait_time << " " << sc->info.numStates << " " << sc->info.bestScoreOffset << std::endl; + } + } + + require(sc->computed.load()); + + size_t output_idx = lexicon_mapping_[w->id()]; + useOutput(*sc, output_idx); + sc->last_used = current_time_; + auto start = std::chrono::steady_clock::now(); + Score score = output_transform_function_(transformOutput(sc->nn_output, output_idx)); + auto end = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration(end - start); + fwd_statistics_.softmax_output_duration += duration; + fwd_statistics_.total_duration += duration; + return score; +} + +template +bool RecurrentLanguageModel::scoreCached(History const& hist, Token w) const { + ScoresWithContext const* sc = reinterpret_cast(hist.handle()); + return sc->computed.load(); +} + +template +void RecurrentLanguageModel::load() { + loadVocabulary(); +} + +template +void RecurrentLanguageModel::startFrame(Search::TimeframeIndex time) const { + auto timer_start = std::chrono::steady_clock::now(); + + current_time_ = time; + + size_t nn_output_cache_size = 0ul; + size_t state_cache_size = 0ul; + size_t 
num_histories = 0ul; + + detail::clear_queue(finished_queue_); + + NNHistoryManager* hm = dynamic_cast(historyManager_); + hm->visit([&](HistoryHandle h) { + num_histories += 1ul; + ScoresWithContext* c = const_cast(reinterpret_cast(h)); + bool computed = c->computed.load(); + if (free_memory_ and computed and c->was_expanded and c->info.numStates == 0 and c->last_used < current_time_ - std::min(free_memory_delay_, current_time_)) { + c->nn_output->clear(); + c->computed.store(false); + } + else if (async_ and not computed and not(c->was_expanded and c->info.numStates == 0)) { + fwd_queue_.enqueue(new History(history(h))); + } + if (c->nn_output) { + nn_output_cache_size += c->nn_output->usedMemory(); + } + for (auto const& state_vec : c->state) { + if (state_vec) { + state_cache_size += state_vec->usedMemory(); + } + } + }); + + if (log_memory_ and statistics_.isOpen()) { + statistics_ << Core::XmlOpen("memory-usage") + Core::XmlAttribute("time-frame", current_time_); + statistics_ << Core::XmlOpen("nn-output-cache-size") + Core::XmlAttribute("unit", "MB") << (nn_output_cache_size / (1024. * 1024.)) << Core::XmlClose("nn-output-cache-size"); + statistics_ << Core::XmlOpen("state-cache-size") + Core::XmlAttribute("unit", "MB") << (state_cache_size / (1024. * 1024.)) << Core::XmlClose("state-cache-size"); + statistics_ << Core::XmlOpen("num-histories") << num_histories << Core::XmlClose("num-histories"); + statistics_ << Core::XmlClose("memory-usage"); + } + + auto timer_end = std::chrono::steady_clock::now(); + double start_frame_duration = std::chrono::duration(timer_end - timer_start).count(); + total_start_frame_time_ += start_frame_duration; +} + +template +void RecurrentLanguageModel::setInfo(History const& hist, SearchSpaceInformation const& info) const { + ScoresWithContext* sc = const_cast(reinterpret_cast(hist.handle())); + sc->info = info; + sc->last_info = current_time_; +} + +template +History RecurrentLanguageModel::extendHistoryWithOutputIdx(History const& hist, size_t w) const { + auto timer_start = std::chrono::steady_clock::now(); + NNHistoryManager* hm = dynamic_cast(historyManager_); + ScoresWithContext const* sc = reinterpret_cast(hist.handle()); + TokenIdSequence ts(*sc->history); + ts.push_back(w); + HistoryHandle h = hm->get(ts); + ScoresWithContext* cache = const_cast(reinterpret_cast(h)); + if (cache->parent.handle() == nullptr) { + cache->parent = hist; + ScoresWithContext* parent_cache = const_cast(reinterpret_cast(hist.handle())); + parent_cache->was_expanded = true; + if (async_) { + fwd_queue_.enqueue(new History(history(h))); + } + } + History ext_hist(history(h)); + if (cache->history->size() > history_pruning_threshold_) { + ext_hist = reducedHistory(ext_hist, pruned_history_length_); + } + auto timer_end = std::chrono::steady_clock::now(); + double expand_hist_time = std::chrono::duration(timer_end - timer_start).count(); + total_expand_hist_time_ += expand_hist_time; + return ext_hist; +} + +template +void RecurrentLanguageModel::background_forward() const { + while (not should_stop_) { + forward(to_fwd_.exchange(nullptr)); + } + History const* h = nullptr; + while (fwd_queue_.try_dequeue(h)) { + finished_queue_.enqueue(h); + } + for (History const* h : pending_) { + finished_queue_.enqueue(h); + } + pending_.clear(); +} + +template +template +void RecurrentLanguageModel::forward(Lm::History const* hist) const { + ScoresWithContext* sc = nullptr; + if (hist != nullptr) { + sc = const_cast(reinterpret_cast(hist->handle())); + } + if (async and sc != 
nullptr and sc->computed.load()) { // nothing to do (only happens in async case) + to_fwd_finished_.set_value(hist); + return; + } + auto start = std::chrono::steady_clock::now(); + + detail::RequestGraph request_graph; + if (sc != nullptr) { + request_graph.add_cache(const_cast(sc)); + } + + std::vector requests; + std::vector request_histories; // make sure none of the request caches go away while we compute the scores + size_t max_length = 0ul; + + size_t num_pending_requests = pending_.size(); + std::unordered_set handles; // only relevant in async case + handles.reserve(pending_.size()); + std::vector early_requests; + std::vector early_request_histories; // make sure none of the request caches go away while we compute the scores (only relevant in async case) + + if (async) { + auto process_hist = [&](History const* hist) { + ScoresWithContext* c = const_cast(reinterpret_cast(hist->handle())); + ScoresWithContext* parent_c = const_cast(reinterpret_cast(c->parent.handle())); + if (handles.find(hist->handle()) == handles.end() and not c->computed.load() and c != sc and c->parent.handle() != nullptr and c->ref_count > 1 and (not single_step_only_ or parent_c->computed.load())) { + early_requests.emplace_back(c); + early_request_histories.emplace_back(hist); + handles.insert(hist->handle()); + } + else { + finished_queue_.enqueue(hist); + } + }; + + std::for_each(pending_.begin(), pending_.end(), process_hist); + pending_.clear(); + + History const* hist_buf = nullptr; + bool success = false; + bool first = true; + do { + if (first) { + success = fwd_queue_.wait_dequeue_timed(hist_buf, 1000); + } + else { + success = fwd_queue_.try_dequeue(hist_buf); + } + if (success) { + process_hist(hist_buf); + first = false; + } + } while (success); + } + else { + NNHistoryManager* hm = dynamic_cast(historyManager_); + hm->visit([&](HistoryHandle h) { + ScoresWithContext* c = const_cast(reinterpret_cast(h)); + if (not c->computed.load() and c != sc and not(c->was_expanded and c->info.numStates == 0)) { + early_requests.emplace_back(c); + } + }); + } + + size_t num_early_requests = early_requests.size(); + + auto end_early_requests = std::chrono::steady_clock::now(); + + if (async and sc == nullptr and early_requests.empty()) { + // can only happen in async case + return; + } + + // because the scores can be updated while we are sorting we need to store them, so we get a consistent view + std::vector> idxs; + idxs.reserve(early_requests.size()); + for (size_t i = 0ul; i < early_requests.size(); i++) { + idxs.emplace_back(i, early_requests[i]->info.minLabelDistance * 1000 + early_requests[i]->info.bestScoreOffset); + } + std::sort(idxs.begin(), idxs.end(), [](std::pair const& a, std::pair const& b) { + return a.second < b.second; + }); + + for (auto idx : idxs) { + request_graph.add_cache(early_requests[idx.first]); + } + + // we do not need early_requests anymore + early_requests.clear(); + idxs.clear(); + + requests = request_graph.get_requests(); + + // prune requests + if (min_batch_size_ > 0ul and requests.size() > min_batch_size_) { + size_t i = min_batch_size_; + Score ref_score = requests.front().final_cache->info.bestScoreOffset + batch_pruning_threshold_; + if (not Math::isinf(ref_score)) { + while ((i + 1) < requests.size() and requests[i + 1].final_cache->info.bestScoreOffset <= ref_score) { + i += 1ul; + } + requests.resize(i); + } + } + + if (min_batch_size_ > 0ul and opt_batch_size_ > 0ul and requests.size() > opt_batch_size_ + min_batch_size_) { + requests.resize(opt_batch_size_); + } + 
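+    // after the score-based pruning beyond min-batch-size and the trim to opt-batch-size above,
+    // apply the hard max-batch-size cap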
if (max_batch_size_ > 0ul and requests.size() > max_batch_size_) { + requests.resize(max_batch_size_); + } + + Score worst_score = std::numeric_limits::min(); + for (auto const& r : requests) { + max_length = std::max(max_length, r.length); + worst_score = std::max(worst_score, r.final_cache->info.bestScoreOffset); + } + + auto end_requests = std::chrono::steady_clock::now(); + + // prepare the data in Sprint Datastructures + Math::FastMatrix words(requests.size(), max_length); + Math::FastVector word_lengths(requests.size()); + for (size_t r = 0ul; r < requests.size(); r++) { + auto& history = *(requests[r].final_cache->history); + size_t offset = history.size() - requests[r].length; + for (size_t w = 0u; w < requests[r].length; w++) { + words.at(r, w) = static_cast(history[offset + w]); + } + for (size_t w = requests[r].length; w < max_length; w++) { + words.at(r, w) = 0; + } + word_lengths[r] = requests[r].length; + ScoresWithContext* initial_cache = requests[r].initial_cache; + require(initial_cache != nullptr); + require_eq(state_variables_.size(), initial_cache->state.size()); + } + + bool full_prefix_required = state_manager_->requiresAllParentStates(); + size_t total_prefix_length = 0ul; + size_t total_suffix_length = 0ul; + + std::vector prefix_lengths(requests.size()); + std::vector suffix_lengths(requests.size()); + for (size_t r = 0ul; r < requests.size(); r++) { + prefix_lengths[r] = requests[r].initial_cache->history->size(); + suffix_lengths[r] = requests[r].length; + total_prefix_length += prefix_lengths[r]; + total_suffix_length += suffix_lengths[r]; + } + + std::vector::HistoryState const*> prefix_states(full_prefix_required ? total_prefix_length : requests.size()); + size_t current_offset = 0ul; + for (size_t r = 0ul; r < requests.size(); r++) { + ScoresWithContext* current_cache = requests[r].initial_cache; + if (full_prefix_required) { + size_t prefix_length = prefix_lengths[r]; + for (size_t i = 0ul; i < prefix_length; i++) { + prefix_states[current_offset + prefix_length - i - 1] = ¤t_cache->state; + current_cache = const_cast(reinterpret_cast(current_cache->parent.handle())); + } + current_offset += prefix_length; + } + else { + prefix_states[r] = ¤t_cache->state; + } + } + + auto end_prepare = std::chrono::steady_clock::now(); + + // build tensors + set state variables + std::vector> inputs; + std::vector targets; + state_manager_->mergeStates(state_variables_, prefix_lengths, prefix_states, inputs, targets); + + std::vector state_lengths(prefix_lengths.begin(), prefix_lengths.end()); + + if (dump_inputs_) { + std::string out = dump_inputs_prefix_ + "_" + std::to_string(dump_inputs_counter_) + "_state_"; + for (size_t i = 0ul; i < inputs.size(); i++) { + inputs[i].second.template save(out + std::to_string(i)); + } + } + + auto end_merge_state = std::chrono::steady_clock::now(); + + setState(inputs, targets); + + auto end_set_state = std::chrono::steady_clock::now(); + + extendInputs(inputs, words, word_lengths, state_lengths); + extendTargets(targets); + + std::vector outputs; + getOutputs(inputs, outputs, targets); + + if (dump_inputs_) { + std::string out = dump_inputs_prefix_ + "_" + std::to_string(dump_inputs_counter_) + "_nn_in_"; + for (size_t i = 0ul; i < inputs.size(); i++) { + inputs[i].second.template save(out + std::to_string(i)); + } + out = dump_inputs_prefix_ + "_" + std::to_string(dump_inputs_counter_) + "_nn_out_"; + for (size_t i = 0ul; i < outputs.size(); i++) { + outputs[i].template save(out + std::to_string(i)); + } + dump_inputs_counter_ += 
1ul; + } + + auto end_nn_output = std::chrono::steady_clock::now(); + + // store outputs in caches + for (size_t r = 0ul; r < requests.size(); r++) { + ScoresWithContext* cache = requests[r].final_cache; + // only final cache get the states + for (size_t w = requests[r].length; w > 0;) { + --w; + cache->last_used = current_time_; + int num_outputs = outputs[0ul].dimSize(2); + auto compression_param_estimator = nn_output_comp_vec_factory_->getEstimator(); + float const* data = outputs[0ul].template data(r, w, 0); + compression_param_estimator->accumulate(data, num_outputs); + auto compression_params = compression_param_estimator->estimate(); + cache->nn_output = nn_output_comp_vec_factory_->compress(data, num_outputs, compression_params.get()); + cache->computed.store(true); + cache = const_cast(reinterpret_cast(cache->parent.handle())); + } + require_eq(cache, requests[r].initial_cache); + } + + auto end_set_nn_output = std::chrono::steady_clock::now(); + + // fetch new values of state variables, needs to be done in separate Session::run call (for GPU devices) + std::vector state_vars = fetchStates(outputs); + + auto end_get_new_state = std::chrono::steady_clock::now(); + + auto split_states = state_manager_->splitStates(state_variables_, suffix_lengths, state_vars, *state_comp_vec_factory_); + + size_t output_offset = 0ul; + for (size_t r = 0ul; r < requests.size(); r++) { + ScoresWithContext* current_cache = requests[r].final_cache; + size_t suffix_length = suffix_lengths[r]; + while (suffix_length > 0ul) { + current_cache->state = std::move(split_states[output_offset + suffix_length - 1]); + current_cache = const_cast(reinterpret_cast(current_cache->parent.handle())); + suffix_length -= 1ul; + } + output_offset += suffix_lengths[r]; + } + + auto end_split_state = std::chrono::steady_clock::now(); + + std::chrono::duration duration = end_split_state - end_prepare; + size_t bucket = requests.size() - 1; + run_time_.at(bucket) += duration.count(); + run_count_.at(bucket) += 1ul; + + if (dump_scores_) { + for (auto const& r : requests) { + detail::dump_scores(*r.final_cache, dump_scores_prefix_); + } + } + + if (async) { + for (auto hist : early_request_histories) { + ScoresWithContext* c = const_cast(reinterpret_cast(hist->handle())); + if (c->computed.load() or c->ref_count == 1ul or c->info.numStates == 0) { + finished_queue_.enqueue(hist); + } + else { + pending_.push_back(hist); + } + } + if (sc != nullptr) { + to_fwd_finished_.set_value(hist); + } + } + + auto end = std::chrono::steady_clock::now(); + + detail::TimeStatistics stats; + stats.total_duration = std::chrono::duration(end - start); + stats.early_request_duration = std::chrono::duration(end_early_requests - start); + stats.request_duration = std::chrono::duration(end_requests - end_early_requests); + stats.prepare_duration = std::chrono::duration(end_prepare - end_requests); + stats.merge_state_duration = std::chrono::duration(end_merge_state - end_prepare); + stats.set_state_duration = std::chrono::duration(end_set_state - end_merge_state); + stats.run_nn_output_duration = std::chrono::duration(end_nn_output - end_set_state); + stats.set_nn_output_duration = std::chrono::duration(end_set_nn_output - end_nn_output); + stats.get_new_state_duration = std::chrono::duration(end_get_new_state - end_set_nn_output); + stats.split_state_duration = std::chrono::duration(end_split_state - end_get_new_state); + stats.softmax_output_duration = std::chrono::duration(); + if (verbose_) { + stats.write(std::cerr); + std::cerr << " #pr:" << 
num_pending_requests + << " #er:" << num_early_requests + << " #r:" << requests.size() << std::endl; + } + fwd_statistics_ += stats; +} + +} // namespace Lm + +#endif // _LM_RECURRENT_LANGUAGE_MODEL_HH diff --git a/src/Lm/ReducedPrecisionCompressedVectorFactory.hh b/src/Lm/ReducedPrecisionCompressedVectorFactory.hh index a4732f8d6..21563e117 100644 --- a/src/Lm/ReducedPrecisionCompressedVectorFactory.hh +++ b/src/Lm/ReducedPrecisionCompressedVectorFactory.hh @@ -24,7 +24,8 @@ namespace Lm { class ReducedBitsFloatVector : public CompressedVector { public: ReducedBitsFloatVector(unsigned drop_bits) - : drop_bits_(drop_bits), bits_per_val_(sizeof(float) * 8 - drop_bits_) { + : drop_bits_(drop_bits), + bits_per_val_(sizeof(float) * 8 - drop_bits_) { } virtual size_t size() const; @@ -49,7 +50,8 @@ public: static const Core::ParameterInt paramDropBits; ReducedPrecisionCompressedVectorFactory(Core::Configuration const& config) - : Precursor(config), drop_bits_(paramDropBits(config)) {} + : Precursor(config), + drop_bits_(paramDropBits(config)) {} virtual ~ReducedPrecisionCompressedVectorFactory() = default; virtual CompressedVectorPtr compress(float const* data, size_t size, CompressionParameters const* params) const; diff --git a/src/Lm/StateManager.hh b/src/Lm/StateManager.hh deleted file mode 100644 index 0efb14143..000000000 --- a/src/Lm/StateManager.hh +++ /dev/null @@ -1,63 +0,0 @@ -/** Copyright 2020 RWTH Aachen University. All rights reserved. - * - * Licensed under the RWTH ASR License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef _LM_STATE_MANAGER_HH -#define _LM_STATE_MANAGER_HH - -#include -#include -#include - -#include "CompressedVector.hh" - -namespace Lm { - -class StateManager : public Core::Component { -public: - using Precursor = Core::Component; - using FeedDict = std::vector>; - using TargetList = std::vector; - using StateVariables = std::vector; - using HistoryState = std::vector>; - - StateManager(Core::Configuration const& config); - virtual ~StateManager() = default; - - virtual bool requiresAllParentStates() const; - - virtual HistoryState initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory) = 0; - virtual void mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets) = 0; - virtual std::vector splitStates(StateVariables const& vars, - std::vector& suffix_lengths, - std::vector const& state_tensors, - CompressedVectorFactory const& vector_factory) = 0; -}; - -// inline implementations - -inline bool StateManager::requiresAllParentStates() const { - return false; -} - -inline StateManager::StateManager(Core::Configuration const& config) - : Precursor(config) { -} - -} // namespace Lm - -#endif // _LM_STATE_MANAGER_HH diff --git a/src/Lm/BlasNceSoftmaxAdapter.cc b/src/Lm/TFBlasNceSoftmaxAdapter.cc similarity index 81% rename from src/Lm/BlasNceSoftmaxAdapter.cc rename to src/Lm/TFBlasNceSoftmaxAdapter.cc index b8e188e3b..02a51f0c7 100644 --- a/src/Lm/BlasNceSoftmaxAdapter.cc +++ b/src/Lm/TFBlasNceSoftmaxAdapter.cc @@ -12,7 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "BlasNceSoftmaxAdapter.hh" +#include "TFBlasNceSoftmaxAdapter.hh" #include @@ -20,13 +20,13 @@ namespace Lm { -void BlasNceSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { +void TFBlasNceSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { auto const& weight_tensor_info = output_map.get_info("weights"); auto const& bias_tensor_info = output_map.get_info("bias"); session.run({}, {weight_tensor_info.tensor_name(), bias_tensor_info.tensor_name()}, {}, tensors_); } -Score BlasNceSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { +Score TFBlasNceSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { std::vector nn_output; float const* data; Lm::UncompressedVector const* vec = dynamic_cast const*>(nn_out.get()); @@ -46,7 +46,7 @@ Score BlasNceSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_ return result; } -std::vector BlasNceSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { +std::vector TFBlasNceSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { std::vector nn_output(nn_out->size()); nn_out->uncompress(nn_output.data(), nn_output.size()); diff --git a/src/Lm/BlasNceSoftmaxAdapter.hh b/src/Lm/TFBlasNceSoftmaxAdapter.hh similarity index 79% rename from src/Lm/BlasNceSoftmaxAdapter.hh rename to src/Lm/TFBlasNceSoftmaxAdapter.hh index c62de4e56..226adbb36 100644 --- a/src/Lm/BlasNceSoftmaxAdapter.hh +++ b/src/Lm/TFBlasNceSoftmaxAdapter.hh @@ -15,16 +15,16 @@ #ifndef _LM_BLAS_NCE_SOFTMAX_ADAPTER_HH #define _LM_BLAS_NCE_SOFTMAX_ADAPTER_HH -#include "SoftmaxAdapter.hh" 
+#include "TFSoftmaxAdapter.hh" namespace Lm { -class BlasNceSoftmaxAdapter : public SoftmaxAdapter { +class TFBlasNceSoftmaxAdapter : public TFSoftmaxAdapter { public: - using Precursor = SoftmaxAdapter; + using Precursor = TFSoftmaxAdapter; - BlasNceSoftmaxAdapter(Core::Configuration const& config); - virtual ~BlasNceSoftmaxAdapter() = default; + TFBlasNceSoftmaxAdapter(Core::Configuration const& config); + virtual ~TFBlasNceSoftmaxAdapter() = default; virtual void init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map); virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx); @@ -36,7 +36,7 @@ private: // inline implementations -inline BlasNceSoftmaxAdapter::BlasNceSoftmaxAdapter(Core::Configuration const& config) +inline TFBlasNceSoftmaxAdapter::TFBlasNceSoftmaxAdapter(Core::Configuration const& config) : Precursor(config) { } diff --git a/src/Lm/TFLstmStateManager.cc b/src/Lm/TFLstmStateManager.cc new file mode 100644 index 000000000..8b3f258dd --- /dev/null +++ b/src/Lm/TFLstmStateManager.cc @@ -0,0 +1,14 @@ + +#include "TFLstmStateManager.hh" + +namespace Lm { + +void TFLstmStateManager::extendFeedDict(FeedDict& feed_dict, Tensorflow::Variable const& state_var, Tensorflow::Tensor& var) { + feed_dict.emplace_back(state_var.initial_value_name, var); +} + +void TFLstmStateManager::extendTargets(TargetList& targets, Tensorflow::Variable const& state_var) { + targets.emplace_back(state_var.initializer_name); +} + +} // namespace Lm diff --git a/src/Lm/TFLstmStateManager.hh b/src/Lm/TFLstmStateManager.hh new file mode 100644 index 000000000..342920870 --- /dev/null +++ b/src/Lm/TFLstmStateManager.hh @@ -0,0 +1,31 @@ +#ifndef _LM_TF_LSTM_STATE_MANAGER_HH +#define _LM_TF_LSTM_STATE_MANAGER_HH + +#include +#include + +#include "LstmStateManager.hh" + +namespace Lm { + +class TFLstmStateManager : public LstmStateManager { +public: + using Precursor = LstmStateManager; + + TFLstmStateManager(Core::Configuration const& config); + virtual ~TFLstmStateManager() = default; + +protected: + virtual void extendFeedDict(FeedDict& feed_dict, Tensorflow::Variable const& state_var, Tensorflow::Tensor& var); + virtual void extendTargets(TargetList& targets, Tensorflow::Variable const& state_var); +}; + +// inline implementations + +inline TFLstmStateManager::TFLstmStateManager(Core::Configuration const& config) + : Precursor(config) { +} + +} // namespace Lm + +#endif // _LM_TF_LSTM_STATE_MANAGER_HH diff --git a/src/Lm/NceSoftmaxAdapter.cc b/src/Lm/TFNceSoftmaxAdapter.cc similarity index 82% rename from src/Lm/NceSoftmaxAdapter.cc rename to src/Lm/TFNceSoftmaxAdapter.cc index 948a5c2c9..20758ce45 100644 --- a/src/Lm/NceSoftmaxAdapter.cc +++ b/src/Lm/TFNceSoftmaxAdapter.cc @@ -12,22 +12,22 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "NceSoftmaxAdapter.hh" +#include "TFNceSoftmaxAdapter.hh" namespace Lm { -void NceSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { +void TFNceSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { session_ = &session; input_map_ = &input_map; output_map_ = &output_map; } -Score NceSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { +Score TFNceSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { std::vector output_idxs(1, output_idx); return get_scores(nn_out, output_idxs)[0]; } -std::vector NceSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { +std::vector TFNceSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { auto const& output_idx_tensor_info = input_map_->get_info("output_idxs"); auto const& nn_output_tensor_info = input_map_->get_info("nn_output"); auto const& softmax_tensor_info = output_map_->get_info("nce_softmax"); diff --git a/src/Lm/NceSoftmaxAdapter.hh b/src/Lm/TFNceSoftmaxAdapter.hh similarity index 81% rename from src/Lm/NceSoftmaxAdapter.hh rename to src/Lm/TFNceSoftmaxAdapter.hh index 68d686277..6e6f23ea8 100644 --- a/src/Lm/NceSoftmaxAdapter.hh +++ b/src/Lm/TFNceSoftmaxAdapter.hh @@ -15,16 +15,16 @@ #ifndef _LM_NCE_SOFTMAX_ADAPTER_HH #define _LM_NCE_SOFTMAX_ADAPTER_HH -#include "SoftmaxAdapter.hh" +#include "TFSoftmaxAdapter.hh" namespace Lm { -class NceSoftmaxAdapter : public SoftmaxAdapter { +class TFNceSoftmaxAdapter : public TFSoftmaxAdapter { public: - using Precursor = SoftmaxAdapter; + using Precursor = TFSoftmaxAdapter; - NceSoftmaxAdapter(Core::Configuration const& config); - virtual ~NceSoftmaxAdapter() = default; + TFNceSoftmaxAdapter(Core::Configuration const& config); + virtual ~TFNceSoftmaxAdapter() = default; virtual void init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map); virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx); @@ -38,7 +38,7 @@ private: // inline implementations -inline NceSoftmaxAdapter::NceSoftmaxAdapter(Core::Configuration const& config) +inline TFNceSoftmaxAdapter::TFNceSoftmaxAdapter(Core::Configuration const& config) : Precursor(config) { } diff --git a/src/Lm/PassthroughSoftmaxAdapter.cc b/src/Lm/TFPassthroughSoftmaxAdapter.cc similarity index 69% rename from src/Lm/PassthroughSoftmaxAdapter.cc rename to src/Lm/TFPassthroughSoftmaxAdapter.cc index 565975c9f..c50a412ff 100644 --- a/src/Lm/PassthroughSoftmaxAdapter.cc +++ b/src/Lm/TFPassthroughSoftmaxAdapter.cc @@ -12,14 +12,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "PassthroughSoftmaxAdapter.hh" +#include "TFPassthroughSoftmaxAdapter.hh" namespace Lm { -void PassthroughSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { +void TFPassthroughSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { } -Score PassthroughSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { +Score TFPassthroughSoftmaxAdapter::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { return nn_out->get(output_idx); } diff --git a/src/Lm/PassthroughSoftmaxAdapter.hh b/src/Lm/TFPassthroughSoftmaxAdapter.hh similarity index 76% rename from src/Lm/PassthroughSoftmaxAdapter.hh rename to src/Lm/TFPassthroughSoftmaxAdapter.hh index 031666662..a65f4743c 100644 --- a/src/Lm/PassthroughSoftmaxAdapter.hh +++ b/src/Lm/TFPassthroughSoftmaxAdapter.hh @@ -15,16 +15,16 @@ #ifndef _LM_PASSTHROUGH_SOFTMAX_ADAPTER_HH #define _LM_PASSTHROUGH_SOFTMAX_ADAPTER_HH -#include "SoftmaxAdapter.hh" +#include "TFSoftmaxAdapter.hh" namespace Lm { -class PassthroughSoftmaxAdapter : public SoftmaxAdapter { +class TFPassthroughSoftmaxAdapter : public TFSoftmaxAdapter { public: - using Precursor = SoftmaxAdapter; + using Precursor = TFSoftmaxAdapter; - PassthroughSoftmaxAdapter(Core::Configuration const& config); - virtual ~PassthroughSoftmaxAdapter() = default; + TFPassthroughSoftmaxAdapter(Core::Configuration const& config); + virtual ~TFPassthroughSoftmaxAdapter() = default; virtual void init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map); virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx); @@ -34,7 +34,7 @@ private: // inline implementations -inline PassthroughSoftmaxAdapter::PassthroughSoftmaxAdapter(Core::Configuration const& config) +inline TFPassthroughSoftmaxAdapter::TFPassthroughSoftmaxAdapter(Core::Configuration const& config) : Precursor(config) { } diff --git a/src/Lm/QuantizedBlasNceSoftmaxAdapter.cc b/src/Lm/TFQuantizedBlasNceSoftmaxAdapter.cc similarity index 90% rename from src/Lm/QuantizedBlasNceSoftmaxAdapter.cc rename to src/Lm/TFQuantizedBlasNceSoftmaxAdapter.cc index 5dabe35b4..c9b672bcf 100644 --- a/src/Lm/QuantizedBlasNceSoftmaxAdapter.cc +++ b/src/Lm/TFQuantizedBlasNceSoftmaxAdapter.cc @@ -12,7 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "QuantizedBlasNceSoftmaxAdapter.hh" +#include "TFQuantizedBlasNceSoftmaxAdapter.hh" #include @@ -50,19 +50,19 @@ float quantized_dot_16bit(size_t size, float scale, s16 const* a, s16 const* b) namespace Lm { template<> -const Core::ParameterFloat QuantizedBlasNceSoftmaxAdapter16Bit::paramNNOutputEpsilon( +const Core::ParameterFloat TFQuantizedBlasNceSoftmaxAdapter16Bit::paramNNOutputEpsilon( "nn-output-epsilon", "if the nn-output vector is not quantized, use this scale for quantization", 0.001, 0.0); template<> -const Core::ParameterFloat QuantizedBlasNceSoftmaxAdapter16Bit::paramWeightsBiasEpsilon( +const Core::ParameterFloat TFQuantizedBlasNceSoftmaxAdapter16Bit::paramWeightsBiasEpsilon( "weights-bias-epsilon", "if the nn-output vector is not quantized, use this scale for quantization", 0.001, 0.0); template<> -Score QuantizedBlasNceSoftmaxAdapter16Bit::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { +Score TFQuantizedBlasNceSoftmaxAdapter16Bit::get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) { std::vector nn_output; s16 const* data; float scale; diff --git a/src/Lm/QuantizedBlasNceSoftmaxAdapter.hh b/src/Lm/TFQuantizedBlasNceSoftmaxAdapter.hh similarity index 74% rename from src/Lm/QuantizedBlasNceSoftmaxAdapter.hh rename to src/Lm/TFQuantizedBlasNceSoftmaxAdapter.hh index afc03aeaa..21465469b 100644 --- a/src/Lm/QuantizedBlasNceSoftmaxAdapter.hh +++ b/src/Lm/TFQuantizedBlasNceSoftmaxAdapter.hh @@ -18,20 +18,20 @@ #include #include -#include "SoftmaxAdapter.hh" +#include "TFSoftmaxAdapter.hh" namespace Lm { template -class QuantizedBlasNceSoftmaxAdapter : public SoftmaxAdapter { +class TFQuantizedBlasNceSoftmaxAdapter : public TFSoftmaxAdapter { public: - using Precursor = SoftmaxAdapter; + using Precursor = TFSoftmaxAdapter; static const Core::ParameterFloat paramNNOutputEpsilon; static const Core::ParameterFloat paramWeightsBiasEpsilon; - QuantizedBlasNceSoftmaxAdapter(Core::Configuration const& config); - virtual ~QuantizedBlasNceSoftmaxAdapter() = default; + TFQuantizedBlasNceSoftmaxAdapter(Core::Configuration const& config); + virtual ~TFQuantizedBlasNceSoftmaxAdapter() = default; virtual void init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map); virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx); @@ -44,18 +44,20 @@ private: Math::FastVector bias_; }; -using QuantizedBlasNceSoftmaxAdapter16Bit = QuantizedBlasNceSoftmaxAdapter; -using QuantizedBlasNceSoftmaxAdapter8Bit = QuantizedBlasNceSoftmaxAdapter; +using TFQuantizedBlasNceSoftmaxAdapter16Bit = TFQuantizedBlasNceSoftmaxAdapter; +using TFQuantizedBlasNceSoftmaxAdapter8Bit = TFQuantizedBlasNceSoftmaxAdapter; // inline implementations template -inline QuantizedBlasNceSoftmaxAdapter::QuantizedBlasNceSoftmaxAdapter(Core::Configuration const& config) - : Precursor(config), nnOutputEpsilon_(paramNNOutputEpsilon(config)), weightsBiasEpsilon_(paramWeightsBiasEpsilon(config)) { +inline TFQuantizedBlasNceSoftmaxAdapter::TFQuantizedBlasNceSoftmaxAdapter(Core::Configuration const& config) + : Precursor(config), + nnOutputEpsilon_(paramNNOutputEpsilon(config)), + weightsBiasEpsilon_(paramWeightsBiasEpsilon(config)) { } template -inline void QuantizedBlasNceSoftmaxAdapter::init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { +inline void TFQuantizedBlasNceSoftmaxAdapter::init(Tensorflow::Session& session, 
Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) { auto const& weight_tensor_info = output_map.get_info("weights"); auto const& bias_tensor_info = output_map.get_info("bias"); std::vector tensors; diff --git a/src/Lm/TFRecurrentLanguageModel.cc b/src/Lm/TFRecurrentLanguageModel.cc index b465599fb..4b759fd47 100644 --- a/src/Lm/TFRecurrentLanguageModel.cc +++ b/src/Lm/TFRecurrentLanguageModel.cc @@ -1,145 +1,12 @@ -/** Copyright 2020 RWTH Aachen University. All rights reserved. - * - * Licensed under the RWTH ASR License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ #include "TFRecurrentLanguageModel.hh" -#include - -#include "BlasNceSoftmaxAdapter.hh" -#include "LstmStateManager.hh" -#include "Module.hh" -#include "NceSoftmaxAdapter.hh" -#include "PassthroughSoftmaxAdapter.hh" -#include "QuantizedBlasNceSoftmaxAdapter.hh" -#include "TransformerStateManager.hh" - -namespace { -struct ScoresWithContext : public Lm::NNCacheWithStats { - virtual ~ScoresWithContext() = default; - - std::atomic computed; - Lm::History parent; - Lm::CompressedVectorPtr nn_output; - std::vector> state; - Lm::SearchSpaceInformation info; - Search::TimeframeIndex last_used; - Search::TimeframeIndex last_info; - bool was_expanded; -}; - -struct FwdRequest { - ScoresWithContext* initial_cache; - ScoresWithContext* final_cache; - size_t length; - - bool operator==(FwdRequest const& other) const { - return final_cache == other.final_cache; - } -}; - -struct RequestGraph { - std::vector entries; - std::vector> children; - std::vector roots; - - void add_cache(ScoresWithContext* cache) { - std::vector request_chain; - request_chain.push_back(cache); - ScoresWithContext* parent = const_cast(reinterpret_cast(cache->parent.handle())); - request_chain.push_back(parent); - while (parent->state.empty()) { - parent = const_cast(reinterpret_cast(parent->parent.handle())); - request_chain.push_back(parent); - } - - std::vector* child_idxs = &roots; - while (not request_chain.empty()) { - // find root node - size_t child_idx = child_idxs->size(); - for (size_t c = 0ul; c < child_idxs->size(); c++) { - if (entries[child_idxs->at(c)] == request_chain.back()) { - child_idx = c; - break; - } - } - size_t next_child_idx = 0ul; - if (child_idx == child_idxs->size()) { - child_idxs->push_back(entries.size()); - entries.push_back(request_chain.back()); - next_child_idx = child_idxs->at(child_idx); - children.emplace_back(); // can invalidate child_idxs - } - else { - next_child_idx = child_idxs->at(child_idx); - } - child_idxs = &children[next_child_idx]; - request_chain.pop_back(); - } - } - - void get_requests_dfs(std::vector& requests, ScoresWithContext* initial, size_t entry, size_t length) { - if (children[entry].empty()) { - requests.emplace_back(FwdRequest{initial, entries[entry], length}); - } - else { - for (size_t e : children[entry]) { - get_requests_dfs(requests, initial, e, length + 1ul); - } - } - } - - std::vector get_requests() { - std::vector result; - for (size_t r : roots) { - for (size_t c : 
children[r]) { - get_requests_dfs(result, entries[r], c, 1ul); - } - } - return result; - } -}; - -void dump_scores(ScoresWithContext const& cache, std::string const& prefix) { - std::stringstream path; - path << prefix; - for (auto token : *cache.history) { - path << "_" << token; - } - std::ofstream out(path.str(), std::ios::out | std::ios::trunc); - out << "nn_output:\n"; - std::vector nn_output(cache.nn_output->size()); - cache.nn_output->uncompress(nn_output.data(), nn_output.size()); - for (auto nn_out : nn_output) { - out << nn_out << '\n'; - } - for (size_t s = 0ul; s < cache.state.size(); s++) { - out << "state " << s << ":\n"; - std::vector state_data(cache.state[s]->size()); - cache.state[s]->uncompress(state_data.data(), state_data.size()); - for (auto v : state_data) { - out << v << '\n'; - } - } -} - -void clear_queue(Lm::TFRecurrentLanguageModel::HistoryQueue& queue) { - Lm::History const* hist = nullptr; - while (queue.try_dequeue(hist)) { - delete hist; - } -} -} // namespace +#include "TFBlasNceSoftmaxAdapter.hh" +#include "TFLstmStateManager.hh" +#include "TFNceSoftmaxAdapter.hh" +#include "TFPassthroughSoftmaxAdapter.hh" +#include "TFQuantizedBlasNceSoftmaxAdapter.hh" +#include "TFSoftmaxAdapter.hh" +#include "TFTransformerStateManager.hh" namespace Lm { @@ -168,187 +35,66 @@ const Core::ParameterChoice stateManagerTypeParam( "type of the state manager", LstmStateManagerType); -std::unique_ptr createStateManager(Core::Configuration const& config) { - StateManager* res = nullptr; +std::unique_ptr createStateManager(Core::Configuration const& config) { + TFStateManager* res = nullptr; switch (stateManagerTypeParam(config)) { - case LstmStateManagerType: res = new Lm::LstmStateManager(config); break; - case TransformerStateManagerType: res = new Lm::TransformerStateManager(config); break; - case TransformerStateManager16BitType: res = new Lm::TransformerStateManager(config); break; - case TransformerStateManager8BitType: res = new Lm::TransformerStateManager(config); break; - case TransformerStateManagerWithCommonPrefixType: res = new Lm::TransformerStateManagerWithCommonPrefix(config); break; - case TransformerStateManagerWithCommonPrefix16BitType: res = new Lm::TransformerStateManagerWithCommonPrefix(config); break; - case TransformerStateManagerWithCommonPrefix8BitType: res = new Lm::TransformerStateManagerWithCommonPrefix(config); break; + case LstmStateManagerType: res = new Lm::TFLstmStateManager(config); break; + case TransformerStateManagerType: res = new Lm::TFTransformerStateManager(config); break; + case TransformerStateManager16BitType: res = new Lm::TFTransformerStateManager(config); break; + case TransformerStateManager8BitType: res = new Lm::TFTransformerStateManager(config); break; + case TransformerStateManagerWithCommonPrefixType: res = new Lm::TFTransformerStateManagerWithCommonPrefix(config); break; + case TransformerStateManagerWithCommonPrefix16BitType: res = new Lm::TFTransformerStateManagerWithCommonPrefix(config); break; + case TransformerStateManagerWithCommonPrefix8BitType: res = new Lm::TFTransformerStateManagerWithCommonPrefix(config); break; default: defect(); } - return std::unique_ptr(res); + return std::unique_ptr(res); } enum SoftmaxAdapterType { - BlasNceSoftmaxAdapterType, - NceSoftmaxAdapterType, - PassthroughSoftmaxAdapterType, - QuantizedBlasNceSoftmaxAdapter16BitType + TFBlasNceSoftmaxAdapterType, + TFNceSoftmaxAdapterType, + TFPassthroughSoftmaxAdapterType, + TFQuantizedBlasNceSoftmaxAdapter16BitType }; const Core::Choice 
softmaxAdapterTypeChoice( - "blas_nce", BlasNceSoftmaxAdapterType, // included for backward compatibility - "blas-nce", BlasNceSoftmaxAdapterType, // more consistent with RASR conventions - "nce", NceSoftmaxAdapterType, - "passthrough", PassthroughSoftmaxAdapterType, - "quantized-blas-nce-16bit", QuantizedBlasNceSoftmaxAdapter16BitType, + "blas_nce", TFBlasNceSoftmaxAdapterType, // included for backward compatibility + "blas-nce", TFBlasNceSoftmaxAdapterType, // more consistent with RASR conventions + "nce", TFNceSoftmaxAdapterType, + "passthrough", TFPassthroughSoftmaxAdapterType, + "quantized-blas-nce-16bit", TFQuantizedBlasNceSoftmaxAdapter16BitType, Core::Choice::endMark()); const Core::ParameterChoice softmaxAdapterTypeParam( "type", &softmaxAdapterTypeChoice, "type of the softmax adapter", - PassthroughSoftmaxAdapterType); + TFPassthroughSoftmaxAdapterType); -std::unique_ptr createSoftmaxAdapter(Core::Configuration const& config) { +std::unique_ptr createSoftmaxAdapter(Core::Configuration const& config) { switch (softmaxAdapterTypeParam(config)) { - case BlasNceSoftmaxAdapterType: return std::unique_ptr(new Lm::BlasNceSoftmaxAdapter(config)); - case NceSoftmaxAdapterType: return std::unique_ptr(new Lm::NceSoftmaxAdapter(config)); - case PassthroughSoftmaxAdapterType: return std::unique_ptr(new Lm::PassthroughSoftmaxAdapter(config)); - case QuantizedBlasNceSoftmaxAdapter16BitType: return std::unique_ptr(new Lm::QuantizedBlasNceSoftmaxAdapter16Bit(config)); + case TFBlasNceSoftmaxAdapterType: return std::unique_ptr(new Lm::TFBlasNceSoftmaxAdapter(config)); + case TFNceSoftmaxAdapterType: return std::unique_ptr(new Lm::TFNceSoftmaxAdapter(config)); + case TFPassthroughSoftmaxAdapterType: return std::unique_ptr(new Lm::TFPassthroughSoftmaxAdapter(config)); + case TFQuantizedBlasNceSoftmaxAdapter16BitType: return std::unique_ptr(new Lm::TFQuantizedBlasNceSoftmaxAdapter16Bit(config)); default: defect(); } } -TFRecurrentLanguageModel::TimeStatistics TFRecurrentLanguageModel::TimeStatistics::operator+(TimeStatistics const& other) const { - TimeStatistics res; - - res.total_duration = total_duration + other.total_duration; - res.early_request_duration = early_request_duration + other.early_request_duration; - res.request_duration = request_duration + other.request_duration; - res.prepare_duration = prepare_duration + other.prepare_duration; - res.merge_state_duration = merge_state_duration + other.merge_state_duration; - res.set_state_duration = set_state_duration + other.set_state_duration; - res.run_nn_output_duration = run_nn_output_duration + other.run_nn_output_duration; - res.set_nn_output_duration = set_nn_output_duration + other.set_nn_output_duration; - res.get_new_state_duration = get_new_state_duration + other.get_new_state_duration; - res.split_state_duration = split_state_duration + other.split_state_duration; - res.softmax_output_duration = softmax_output_duration + other.softmax_output_duration; - - return res; -} - -TFRecurrentLanguageModel::TimeStatistics& TFRecurrentLanguageModel::TimeStatistics::operator+=(TimeStatistics const& other) { - total_duration += other.total_duration; - early_request_duration += other.early_request_duration; - request_duration += other.request_duration; - prepare_duration += other.prepare_duration; - merge_state_duration += other.merge_state_duration; - set_state_duration += other.set_state_duration; - run_nn_output_duration += other.run_nn_output_duration; - set_nn_output_duration += other.set_nn_output_duration; - get_new_state_duration += 
other.get_new_state_duration; - split_state_duration += other.split_state_duration; - softmax_output_duration += other.softmax_output_duration; - - return *this; -} - -void TFRecurrentLanguageModel::TimeStatistics::write(Core::XmlChannel& channel) const { - channel << Core::XmlOpen("total-duration") + Core::XmlAttribute("unit", "milliseconds") << total_duration.count() << Core::XmlClose("total-duration"); - channel << Core::XmlOpen("early-request-duration") + Core::XmlAttribute("unit", "milliseconds") << early_request_duration.count() << Core::XmlClose("early-request-duration"); - channel << Core::XmlOpen("request-duration") + Core::XmlAttribute("unit", "milliseconds") << request_duration.count() << Core::XmlClose("request-duration"); - channel << Core::XmlOpen("prepare-duration") + Core::XmlAttribute("unit", "milliseconds") << prepare_duration.count() << Core::XmlClose("prepare-duration"); - channel << Core::XmlOpen("merge-state-duration") + Core::XmlAttribute("unit", "milliseconds") << merge_state_duration.count() << Core::XmlClose("merge-state-duration"); - channel << Core::XmlOpen("set-state-duration") + Core::XmlAttribute("unit", "milliseconds") << set_state_duration.count() << Core::XmlClose("set-state-duration"); - channel << Core::XmlOpen("run-nn-output-duration") + Core::XmlAttribute("unit", "milliseconds") << run_nn_output_duration.count() << Core::XmlClose("run-nn-output-duration"); - channel << Core::XmlOpen("set-nn-output-duration") + Core::XmlAttribute("unit", "milliseconds") << set_nn_output_duration.count() << Core::XmlClose("set-nn-output-duration"); - channel << Core::XmlOpen("get-new-state-duration") + Core::XmlAttribute("unit", "milliseconds") << get_new_state_duration.count() << Core::XmlClose("get-new-state-duration"); - channel << Core::XmlOpen("split-state-duration") + Core::XmlAttribute("unit", "milliseconds") << split_state_duration.count() << Core::XmlClose("split-state-duration"); - channel << Core::XmlOpen("softmax-output-duration") + Core::XmlAttribute("unit", "milliseconds") << softmax_output_duration.count() << Core::XmlClose("softmax-output-duration"); -} - -void TFRecurrentLanguageModel::TimeStatistics::write(std::ostream& out) const { - out << "fwd: " << total_duration.count() - << " er:" << early_request_duration.count() - << " r:" << request_duration.count() - << " p:" << prepare_duration.count() - << " ms: " << merge_state_duration.count() - << " sst:" << set_state_duration.count() - << " rs:" << run_nn_output_duration.count() - << " sno:" << set_nn_output_duration.count() - << " gns:" << get_new_state_duration.count() - << " ss: " << split_state_duration.count() - << " smo:" << softmax_output_duration.count(); -} - -const Core::ParameterBool TFRecurrentLanguageModel::paramTransformOuputLog("transform-output-log", "apply log to tensorflow output", false); -const Core::ParameterBool TFRecurrentLanguageModel::paramTransformOuputNegate("transform-output-negate", "negate tensorflow output (after log)", false); -const Core::ParameterInt TFRecurrentLanguageModel::paramMinBatchSize("min-batch-size", "minimum number of histories forwarded in one go", 32); -const Core::ParameterInt TFRecurrentLanguageModel::paramOptBatchSize("opt-batch-size", "optimum number of histories forwarded in one go", 128); -const Core::ParameterInt TFRecurrentLanguageModel::paramMaxBatchSize("max-batch-size", "maximum number of histories forwarded in one go", 2048); -const Core::ParameterInt TFRecurrentLanguageModel::paramHistoryPruningThreshold("history-pruning-threshold", "if the 
history is longer than this parameter it will be pruned", std::numeric_limits::max(), 0); -const Core::ParameterInt TFRecurrentLanguageModel::paramPrunedHistoryLength("pruned-history-length", "length of the pruned history (should be smaller than history-pruning-threshold)", std::numeric_limits::max(), 0); -const Core::ParameterFloat TFRecurrentLanguageModel::paramBatchPruningThreshold("batch-pruning-threshold", "pruning threshold for all hypothesis beyond min-batch-size during eager forwarding", 10.0); -const Core::ParameterBool TFRecurrentLanguageModel::paramAllowReducedHistory("allow-reduced-history", "wether this LM will actually reduce the history length", false); -const Core::ParameterBool TFRecurrentLanguageModel::paramDumpInputs("dump-inputs", "write all inputs from this LM to disk", false); -const Core::ParameterString TFRecurrentLanguageModel::paramDumpInputsPrefix("dump-inputs-prefix", "prefix for the input dumps", "inputs"); -const Core::ParameterBool TFRecurrentLanguageModel::paramDumpScores("dump-scores", "write all scores from this LM to disk", false); -const Core::ParameterString TFRecurrentLanguageModel::paramDumpScoresPrefix("dump-scores-prefix", "prefix for the score dumps", "scores"); -const Core::ParameterBool TFRecurrentLanguageModel::paramLogMemory("log-memory", "wether memory usage from nn-outputs / states should be logged", false); -const Core::ParameterBool TFRecurrentLanguageModel::paramFreeMemory("free-memory", "wether nn-outputs should be deleted after some delay", false); -const Core::ParameterInt TFRecurrentLanguageModel::paramFreeMemoryDelay("free-memory-delay", "how many time frames without usage before nn-outputs are deleted", 40); -const Core::ParameterBool TFRecurrentLanguageModel::paramAsync("async", "wether to forward histories in a separate thread", false); -const Core::ParameterBool TFRecurrentLanguageModel::paramSingleStepOnly("single-step-only", "workaround for some bug that results in wrong scores when recombination is done in combination with async evaluation", false); -const Core::ParameterBool TFRecurrentLanguageModel::paramVerbose("verbose", "wether to print detailed statistics to stderr", false); - TFRecurrentLanguageModel::TFRecurrentLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l) : Core::Component(c), - Precursor(c, l), - transform_output_log_(paramTransformOuputLog(config)), - transform_output_negate_(paramTransformOuputNegate(config)), - min_batch_size_(paramMinBatchSize(config)), - opt_batch_size_(paramOptBatchSize(config)), - max_batch_size_(paramMaxBatchSize(config)), - history_pruning_threshold_(paramHistoryPruningThreshold(config)), - pruned_history_length_(paramPrunedHistoryLength(config)), - batch_pruning_threshold_(paramBatchPruningThreshold(config)), - allow_reduced_history_(paramAllowReducedHistory(config)), - dump_inputs_(paramDumpInputs(config)), - dump_inputs_prefix_(paramDumpInputsPrefix(config)), - dump_scores_(paramDumpScores(config)), - dump_scores_prefix_(paramDumpScoresPrefix(config)), - log_memory_(paramLogMemory(config)), - free_memory_(paramFreeMemory(config)), - free_memory_delay_(paramFreeMemoryDelay(config)), - async_(paramAsync(config)), - single_step_only_(paramSingleStepOnly(config)), - verbose_(paramVerbose(config)), + Precursor(c, l, createStateManager(select("state-manager"))), session_(select("session")), loader_(Tensorflow::Module::instance().createGraphLoader(select("loader"))), graph_(loader_->load_graph()), tensor_input_map_(select("input-map")), tensor_output_map_(select("output-map")), 
- state_comp_vec_factory_(Lm::Module::instance().createCompressedVectorFactory(select("state-compression"))), - nn_output_comp_vec_factory_(Lm::Module::instance().createCompressedVectorFactory(select("nn-output-compression"))), - state_manager_(createStateManager(select("state-manager"))), - softmax_adapter_(createSoftmaxAdapter(select("softmax-adapter"))), - statistics_(config, "statistics"), - current_time_(0u), - run_time_(max_batch_size_, 0.0), - run_count_(max_batch_size_, 0ul), - total_wait_time_(0.0), - total_start_frame_time_(0.0), - total_expand_hist_time_(0.0), - fwd_statistics_(), - dump_inputs_counter_(0ul), - background_forwarder_thread_(), - should_stop_(false), - to_fwd_(nullptr), - to_fwd_finished_(), - pending_(), - fwd_queue_(32768), - finished_queue_(32768) { - require_le(pruned_history_length_, history_pruning_threshold_); + softmax_adapter_(createSoftmaxAdapter(select("softmax-adapter"))) { session_.addGraph(*graph_); loader_->initialize(session_); auto const& softmax_info = tensor_output_map_.get_info("softmax"); output_tensor_names_.push_back(softmax_info.tensor_name()); - state_variables_.reserve(state_variables_.size()); + state_variables_.reserve(graph_->state_vars().size()); for (std::string const& s : graph_->state_vars()) { auto const& var = graph_->variables().find(s)->second; state_variables_.emplace_back(var); @@ -356,443 +102,20 @@ TFRecurrentLanguageModel::TFRecurrentLanguageModel(Core::Configuration const& c, read_vars_tensor_names_.push_back(var.snapshot_name); } - if (transform_output_log_ and transform_output_negate_) { - output_transform_function_ = [](Score v) { return -std::log(v); }; - } - else if (transform_output_log_) { - output_transform_function_ = [](Score v) { return std::log(v); }; - } - else if (transform_output_negate_) { - output_transform_function_ = [](Score v) { return -v; }; + if (state_variables_.empty()) { + error("No recurrent state variables found in tensorflow graph of recurrent language model."); } - NNHistoryManager* hm = dynamic_cast(historyManager_); - TokenIdSequence ts; - HistoryHandle h = hm->get(ts); - ScoresWithContext* cache = const_cast(reinterpret_cast(h)); - cache->state = state_manager_->initialState(state_variables_, *state_comp_vec_factory_); - - if (cache->state.empty()) { - error("LM has no state variables. 
Did you forget to compile with 'initial_state': 'keep_over_epoch_no_init'?"); - } - - std::vector temp(1); - auto compression_param_estimator = nn_output_comp_vec_factory_->getEstimator(); - compression_param_estimator->accumulate(temp.data(), temp.size()); - auto compression_params = compression_param_estimator->estimate(); - // pretend this history has already been evaluated - cache->nn_output = nn_output_comp_vec_factory_->compress(temp.data(), temp.size(), compression_params.get()); - cache->computed.store(true); - cache->last_used = std::numeric_limits::max(); - empty_history_ = history(h); + setEmptyHistory(); softmax_adapter_->init(session_, tensor_input_map_, tensor_output_map_); - - if (async_) { - background_forwarder_thread_ = std::thread(std::bind(&TFRecurrentLanguageModel::background_forward, this)); - } -} - -TFRecurrentLanguageModel::~TFRecurrentLanguageModel() { - clear_queue(finished_queue_); - - if (async_) { - should_stop_ = true; - background_forwarder_thread_.join(); - } - - size_t total_run_count = 0ul; - size_t total_fwd_hist = 0ul; - double total_run_time = 0.0; - - statistics_ << Core::XmlOpen("fwd-time"); - for (size_t i = 0ul; i < run_count_.size(); i++) { - if (run_count_[i] > 0ul) { - statistics_ << (i + 1) << " " << run_count_[i] << " " << run_time_[i] << "\n"; - total_run_count += run_count_[i]; - total_fwd_hist += (i + 1) * run_count_[i]; - total_run_time += run_time_[i]; - } - } - statistics_ << Core::XmlClose("fwd-time"); - - statistics_ << Core::XmlOpen("fwd-summary"); - statistics_ << Core::XmlOpen("total-run-count") << total_run_count << Core::XmlClose("total-run-count"); - statistics_ << Core::XmlOpen("total-fwd-hist") << total_fwd_hist << Core::XmlClose("total-fwd-hist"); - statistics_ << Core::XmlOpen("total-run-time") + Core::XmlAttribute("unit", "milliseconds") << total_run_time << Core::XmlClose("total-run-time"); - statistics_ << Core::XmlOpen("total-wait-time") + Core::XmlAttribute("unit", "milliseconds") << total_wait_time_ << Core::XmlClose("total-wait-time"); - statistics_ << Core::XmlOpen("total-start-frame-time") + Core::XmlAttribute("unit", "milliseconds") << total_start_frame_time_ << Core::XmlClose("total-start-frame-time"); - statistics_ << Core::XmlOpen("total-expand-hist-time") + Core::XmlAttribute("unit", "milliseconds") << total_expand_hist_time_ << Core::XmlClose("total-expand-hist-time"); - statistics_ << Core::XmlOpen("fwd-times"); - fwd_statistics_.write(statistics_); - statistics_ << Core::XmlClose("fwd-times"); - statistics_ << Core::XmlClose("fwd-summary"); -} - -History TFRecurrentLanguageModel::startHistory() const { - NNHistoryManager* hm = dynamic_cast(historyManager_); - TokenIdSequence ts(1ul, lexicon_mapping_[sentenceBeginToken()->id()]); - HistoryHandle h = hm->get(ts); - ScoresWithContext* cache = const_cast(reinterpret_cast(h)); - cache->parent = empty_history_; - History hist(history(h)); - return hist; -} - -History TFRecurrentLanguageModel::extendedHistory(History const& hist, Token w) const { - return extendedHistory(hist, w->id()); -} - -History TFRecurrentLanguageModel::extendedHistory(History const& hist, Bliss::Token::Id w) const { - return extendHistoryWithOutputIdx(hist, lexicon_mapping_[w]); -} - -History TFRecurrentLanguageModel::reducedHistory(History const& hist, u32 limit) const { - ScoresWithContext const* sc = reinterpret_cast(hist.handle()); - if (not allow_reduced_history_ or sc->history->size() <= limit) { - return hist; - } - History h = startHistory(); - for (u32 w = limit; w > 0; w--) { - h = 
extendHistoryWithOutputIdx(h, sc->history->at(sc->history->size() - w)); - } - return h; -} - -Score TFRecurrentLanguageModel::score(History const& hist, Token w) const { - ScoresWithContext* sc = const_cast(reinterpret_cast(hist.handle())); - - if (not sc->computed.load()) { - auto start = std::chrono::steady_clock::now(); - if (async_) { - // promise should only be used once - to_fwd_finished_ = std::promise(); - std::future future = to_fwd_finished_.get_future(); - to_fwd_.store(&hist); - future.wait(); - } - else { - forward(&hist); - } - auto end = std::chrono::steady_clock::now(); - double wait_time = std::chrono::duration(end - start).count(); - total_wait_time_ += wait_time; - if (verbose_) { - std::cerr << "wait: " << wait_time << " " << sc->info.numStates << " " << sc->info.bestScoreOffset << std::endl; - } - } - - require(sc->computed.load()); - - size_t output_idx = lexicon_mapping_[w->id()]; - useOutput(*sc, output_idx); - sc->last_used = current_time_; - auto start = std::chrono::steady_clock::now(); - Score score = output_transform_function_(softmax_adapter_->get_score(sc->nn_output, output_idx)); - auto end = std::chrono::steady_clock::now(); - auto duration = std::chrono::duration(end - start); - fwd_statistics_.softmax_output_duration += duration; - fwd_statistics_.total_duration += duration; - return score; -} - -bool TFRecurrentLanguageModel::scoreCached(History const& hist, Token w) const { - ScoresWithContext const* sc = reinterpret_cast(hist.handle()); - return sc->computed.load(); -} - -void TFRecurrentLanguageModel::load() { - loadVocabulary(); -} - -void TFRecurrentLanguageModel::startFrame(Search::TimeframeIndex time) const { - auto timer_start = std::chrono::steady_clock::now(); - - current_time_ = time; - - size_t nn_output_cache_size = 0ul; - size_t state_cache_size = 0ul; - size_t num_histories = 0ul; - - clear_queue(finished_queue_); - - NNHistoryManager* hm = dynamic_cast(historyManager_); - hm->visit([&](HistoryHandle h) { - num_histories += 1ul; - ScoresWithContext* c = const_cast(reinterpret_cast(h)); - bool computed = c->computed.load(); - if (free_memory_ and computed and c->was_expanded and c->info.numStates == 0 and c->last_used < current_time_ - std::min(free_memory_delay_, current_time_)) { - c->nn_output->clear(); - c->computed.store(false); - } - else if (async_ and not computed and not(c->was_expanded and c->info.numStates == 0)) { - fwd_queue_.enqueue(new History(history(h))); - } - if (c->nn_output) { - nn_output_cache_size += c->nn_output->usedMemory(); - } - for (auto const& state_vec : c->state) { - if (state_vec) { - state_cache_size += state_vec->usedMemory(); - } - } - }); - - if (log_memory_ and statistics_.isOpen()) { - statistics_ << Core::XmlOpen("memory-usage") + Core::XmlAttribute("time-frame", current_time_); - statistics_ << Core::XmlOpen("nn-output-cache-size") + Core::XmlAttribute("unit", "MB") << (nn_output_cache_size / (1024. * 1024.)) << Core::XmlClose("nn-output-cache-size"); - statistics_ << Core::XmlOpen("state-cache-size") + Core::XmlAttribute("unit", "MB") << (state_cache_size / (1024. 
* 1024.)) << Core::XmlClose("state-cache-size"); - statistics_ << Core::XmlOpen("num-histories") << num_histories << Core::XmlClose("num-histories"); - statistics_ << Core::XmlClose("memory-usage"); - } - - auto timer_end = std::chrono::steady_clock::now(); - double start_frame_duration = std::chrono::duration(timer_end - timer_start).count(); - total_start_frame_time_ += start_frame_duration; -} - -void TFRecurrentLanguageModel::setInfo(History const& hist, SearchSpaceInformation const& info) const { - ScoresWithContext* sc = const_cast(reinterpret_cast(hist.handle())); - sc->info = info; - sc->last_info = current_time_; -} - -History TFRecurrentLanguageModel::extendHistoryWithOutputIdx(History const& hist, size_t w) const { - auto timer_start = std::chrono::steady_clock::now(); - NNHistoryManager* hm = dynamic_cast(historyManager_); - ScoresWithContext const* sc = reinterpret_cast(hist.handle()); - TokenIdSequence ts(*sc->history); - ts.push_back(w); - HistoryHandle h = hm->get(ts); - ScoresWithContext* cache = const_cast(reinterpret_cast(h)); - if (cache->parent.handle() == nullptr) { - cache->parent = hist; - ScoresWithContext* parent_cache = const_cast(reinterpret_cast(hist.handle())); - parent_cache->was_expanded = true; - if (async_) { - fwd_queue_.enqueue(new History(history(h))); - } - } - History ext_hist(history(h)); - if (cache->history->size() > history_pruning_threshold_) { - ext_hist = reducedHistory(ext_hist, pruned_history_length_); - } - auto timer_end = std::chrono::steady_clock::now(); - double expand_hist_time = std::chrono::duration(timer_end - timer_start).count(); - total_expand_hist_time_ += expand_hist_time; - return ext_hist; -} - -void TFRecurrentLanguageModel::background_forward() const { - while (not should_stop_) { - forward(to_fwd_.exchange(nullptr)); - } - History const* h = nullptr; - while (fwd_queue_.try_dequeue(h)) { - finished_queue_.enqueue(h); - } - for (History const* h : pending_) { - finished_queue_.enqueue(h); - } - pending_.clear(); } -template -void TFRecurrentLanguageModel::forward(Lm::History const* hist) const { - ScoresWithContext* sc = nullptr; - if (hist != nullptr) { - sc = const_cast(reinterpret_cast(hist->handle())); - } - if (async and sc != nullptr and sc->computed.load()) { // nothing to do (only happens in async case) - to_fwd_finished_.set_value(hist); - return; - } - auto start = std::chrono::steady_clock::now(); - - RequestGraph request_graph; - if (sc != nullptr) { - request_graph.add_cache(const_cast(sc)); - } - - std::vector requests; - std::vector request_histories; // make sure none of the request caches go away while we compute the scores - size_t max_length = 0ul; - - size_t num_pending_requests = pending_.size(); - std::unordered_set handles; // only relevant in async case - handles.reserve(pending_.size()); - std::vector early_requests; - std::vector early_request_histories; // make sure none of the request caches go away while we compute the scores (only relevant in async case) - - if (async) { - auto process_hist = [&](History const* hist) { - ScoresWithContext* c = const_cast(reinterpret_cast(hist->handle())); - ScoresWithContext* parent_c = const_cast(reinterpret_cast(c->parent.handle())); - if (handles.find(hist->handle()) == handles.end() and not c->computed.load() and c != sc and c->parent.handle() != nullptr and c->ref_count > 1 and (not single_step_only_ or parent_c->computed.load())) { - early_requests.emplace_back(c); - early_request_histories.emplace_back(hist); - handles.insert(hist->handle()); - } - else { 
- finished_queue_.enqueue(hist); - } - }; - - std::for_each(pending_.begin(), pending_.end(), process_hist); - pending_.clear(); - - History const* hist_buf = nullptr; - bool success = false; - bool first = true; - do { - if (first) { - success = fwd_queue_.wait_dequeue_timed(hist_buf, 1000); - } - else { - success = fwd_queue_.try_dequeue(hist_buf); - } - if (success) { - process_hist(hist_buf); - first = false; - } - } while (success); - } - else { - NNHistoryManager* hm = dynamic_cast(historyManager_); - hm->visit([&](HistoryHandle h) { - ScoresWithContext* c = const_cast(reinterpret_cast(h)); - if (not c->computed.load() and c != sc and not(c->was_expanded and c->info.numStates == 0)) { - early_requests.emplace_back(c); - } - }); - } - - size_t num_early_requests = early_requests.size(); - - auto end_early_requests = std::chrono::steady_clock::now(); - - if (async and sc == nullptr and early_requests.empty()) { - // can only happen in async case - return; - } - - // because the scores can be updated while we are sorting we need to store them, so we get a consistent view - std::vector> idxs; - idxs.reserve(early_requests.size()); - for (size_t i = 0ul; i < early_requests.size(); i++) { - idxs.emplace_back(i, early_requests[i]->info.minLabelDistance * 1000 + early_requests[i]->info.bestScoreOffset); - } - std::sort(idxs.begin(), idxs.end(), [](std::pair const& a, std::pair const& b) { - return a.second < b.second; - }); - - for (auto idx : idxs) { - request_graph.add_cache(early_requests[idx.first]); - } - - // we do not need early_requests anymore - early_requests.clear(); - idxs.clear(); - - requests = request_graph.get_requests(); - - // prune requests - if (min_batch_size_ > 0ul and requests.size() > min_batch_size_) { - size_t i = min_batch_size_; - Score ref_score = requests.front().final_cache->info.bestScoreOffset + batch_pruning_threshold_; - if (not Math::isinf(ref_score)) { - while ((i + 1) < requests.size() and requests[i + 1].final_cache->info.bestScoreOffset <= ref_score) { - i += 1ul; - } - requests.resize(i); - } - } - - if (min_batch_size_ > 0ul and opt_batch_size_ > 0ul and requests.size() > opt_batch_size_ + min_batch_size_) { - requests.resize(opt_batch_size_); - } - if (max_batch_size_ > 0ul and requests.size() > max_batch_size_) { - requests.resize(max_batch_size_); - } - - Score worst_score = std::numeric_limits::min(); - for (auto const& r : requests) { - max_length = std::max(max_length, r.length); - worst_score = std::max(worst_score, r.final_cache->info.bestScoreOffset); - } - - auto end_requests = std::chrono::steady_clock::now(); - - // prepare the data in Sprint Datastructures - Math::FastMatrix words(requests.size(), max_length); - Math::FastVector word_lengths(requests.size()); - for (size_t r = 0ul; r < requests.size(); r++) { - auto& history = *(requests[r].final_cache->history); - size_t offset = history.size() - requests[r].length; - for (size_t w = 0u; w < requests[r].length; w++) { - words.at(r, w) = static_cast(history[offset + w]); - } - for (size_t w = requests[r].length; w < max_length; w++) { - words.at(r, w) = 0; - } - word_lengths[r] = requests[r].length; - ScoresWithContext* initial_cache = requests[r].initial_cache; - require(initial_cache != nullptr); - require_eq(state_variables_.size(), initial_cache->state.size()); - } - - bool full_prefix_required = state_manager_->requiresAllParentStates(); - size_t total_prefix_length = 0ul; - size_t total_suffix_length = 0ul; - - std::vector prefix_lengths(requests.size()); - std::vector 
suffix_lengths(requests.size()); - for (size_t r = 0ul; r < requests.size(); r++) { - prefix_lengths[r] = requests[r].initial_cache->history->size(); - suffix_lengths[r] = requests[r].length; - total_prefix_length += prefix_lengths[r]; - total_suffix_length += suffix_lengths[r]; - } - - std::vector prefix_states(full_prefix_required ? total_prefix_length : requests.size()); - size_t current_offset = 0ul; - for (size_t r = 0ul; r < requests.size(); r++) { - ScoresWithContext* current_cache = requests[r].initial_cache; - if (full_prefix_required) { - size_t prefix_length = prefix_lengths[r]; - for (size_t i = 0ul; i < prefix_length; i++) { - prefix_states[current_offset + prefix_length - i - 1] = ¤t_cache->state; - current_cache = const_cast(reinterpret_cast(current_cache->parent.handle())); - } - current_offset += prefix_length; - } - else { - prefix_states[r] = ¤t_cache->state; - } - } - - auto end_prepare = std::chrono::steady_clock::now(); - - // build tensors + set state variables - std::vector> inputs; - std::vector targets; - state_manager_->mergeStates(state_variables_, prefix_lengths, prefix_states, inputs, targets); - std::vector state_lengths(prefix_lengths.begin(), prefix_lengths.end()); - - if (dump_inputs_) { - std::string out = dump_inputs_prefix_ + "_" + std::to_string(dump_inputs_counter_) + "_state_"; - for (size_t i = 0ul; i < inputs.size(); i++) { - inputs[i].second.save(out + std::to_string(i)); - } - } - - auto end_merge_state = std::chrono::steady_clock::now(); - +void TFRecurrentLanguageModel::setState(std::vector> const& inputs, std::vector const& targets) const { session_.run(inputs, targets); +} - auto end_set_state = std::chrono::steady_clock::now(); - - // run nn-output calculation +void TFRecurrentLanguageModel::extendInputs(std::vector>& inputs, Math::FastMatrix const& words, Math::FastVector const& word_lengths, std::vector const& state_lengths) const { inputs.clear(); auto const& word_info = tensor_input_map_.get_info("word"); inputs.emplace_back(std::make_pair(word_info.tensor_name(), Tensorflow::Tensor::create(words))); @@ -803,111 +126,22 @@ void TFRecurrentLanguageModel::forward(Lm::History const* hist) const { auto const& state_lengths_info = tensor_input_map_.get_info("state-lengths"); inputs.emplace_back(std::make_pair(state_lengths_info.tensor_name(), Tensorflow::Tensor::create(state_lengths))); } - std::vector outputs; - session_.run(inputs, output_tensor_names_, graph_->update_ops(), outputs); - - if (dump_inputs_) { - std::string out = dump_inputs_prefix_ + "_" + std::to_string(dump_inputs_counter_) + "_nn_in_"; - for (size_t i = 0ul; i < inputs.size(); i++) { - inputs[i].second.save(out + std::to_string(i)); - } - out = dump_inputs_prefix_ + "_" + std::to_string(dump_inputs_counter_) + "_nn_out_"; - for (size_t i = 0ul; i < outputs.size(); i++) { - outputs[i].save(out + std::to_string(i)); - } - dump_inputs_counter_ += 1ul; - } - - auto end_nn_output = std::chrono::steady_clock::now(); +} - // store outputs in caches - for (size_t r = 0ul; r < requests.size(); r++) { - ScoresWithContext* cache = requests[r].final_cache; - for (size_t w = requests[r].length; w > 0;) { - --w; - cache->last_used = current_time_; - int num_outputs = outputs[0ul].dimSize(2); - auto compression_param_estimator = nn_output_comp_vec_factory_->getEstimator(); - float const* data = outputs[0ul].data(r, w, 0); - compression_param_estimator->accumulate(data, num_outputs); - auto compression_params = compression_param_estimator->estimate(); - cache->nn_output = 
nn_output_comp_vec_factory_->compress(data, num_outputs, compression_params.get()); - cache->computed.store(true); - cache = const_cast(reinterpret_cast(cache->parent.handle())); - } - require_eq(cache, requests[r].initial_cache); - } +void TFRecurrentLanguageModel::extendTargets(std::vector& targets) const { +} - auto end_set_nn_output = std::chrono::steady_clock::now(); +void TFRecurrentLanguageModel::getOutputs(std::vector>& inputs, std::vector& outputs, std::vector const& targets) const { + session_.run(inputs, output_tensor_names_, graph_->update_ops(), outputs); +} - // fetch new values of state variables, needs to be done in separate Session::run call (for GPU devices) +std::vector TFRecurrentLanguageModel::fetchStates(std::vector& outputs) const { session_.run({}, read_vars_tensor_names_, {}, outputs); + return outputs; +} - auto end_get_new_state = std::chrono::steady_clock::now(); - - auto split_states = state_manager_->splitStates(state_variables_, suffix_lengths, outputs, *state_comp_vec_factory_); - - size_t output_offset = 0ul; - for (size_t r = 0ul; r < requests.size(); r++) { - ScoresWithContext* current_cache = requests[r].final_cache; - size_t suffix_length = suffix_lengths[r]; - while (suffix_length > 0ul) { - current_cache->state = std::move(split_states[output_offset + suffix_length - 1]); - current_cache = const_cast(reinterpret_cast(current_cache->parent.handle())); - suffix_length -= 1ul; - } - output_offset += suffix_lengths[r]; - } - - auto end_split_state = std::chrono::steady_clock::now(); - - std::chrono::duration duration = end_split_state - end_prepare; - size_t bucket = requests.size() - 1; - run_time_.at(bucket) += duration.count(); - run_count_.at(bucket) += 1ul; - - if (dump_scores_) { - for (auto const& r : requests) { - dump_scores(*r.final_cache, dump_scores_prefix_); - } - } - - if (async) { - for (auto hist : early_request_histories) { - ScoresWithContext* c = const_cast(reinterpret_cast(hist->handle())); - if (c->computed.load() or c->ref_count == 1ul or c->info.numStates == 0) { - finished_queue_.enqueue(hist); - } - else { - pending_.push_back(hist); - } - } - if (sc != nullptr) { - to_fwd_finished_.set_value(hist); - } - } - - auto end = std::chrono::steady_clock::now(); - - TimeStatistics stats; - stats.total_duration = std::chrono::duration(end - start); - stats.early_request_duration = std::chrono::duration(end_early_requests - start); - stats.request_duration = std::chrono::duration(end_requests - end_early_requests); - stats.prepare_duration = std::chrono::duration(end_prepare - end_requests); - stats.merge_state_duration = std::chrono::duration(end_merge_state - end_prepare); - stats.set_state_duration = std::chrono::duration(end_set_state - end_merge_state); - stats.run_nn_output_duration = std::chrono::duration(end_nn_output - end_set_state); - stats.set_nn_output_duration = std::chrono::duration(end_set_nn_output - end_nn_output); - stats.get_new_state_duration = std::chrono::duration(end_get_new_state - end_set_nn_output); - stats.split_state_duration = std::chrono::duration(end_split_state - end_get_new_state); - stats.softmax_output_duration = std::chrono::duration(); - if (verbose_) { - stats.write(std::cerr); - std::cerr << " #pr:" << num_pending_requests - << " #er:" << num_early_requests - << " #r:" << requests.size() << std::endl; - } - fwd_statistics_ += stats; +Score TFRecurrentLanguageModel::transformOutput(Lm::CompressedVectorPtr const& nn_output, size_t index) const { + return softmax_adapter_->get_score(nn_output, index); } 
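For orientation, a minimal sketch of what a scoring backend looks like against the renamed TFSoftmaxAdapter interface from TFSoftmaxAdapter.hh (constructor taking Core::Configuration, pure virtual init() and get_score()), mirroring TFPassthroughSoftmaxAdapter above. This is illustrative only and not part of the commit: the class name ExampleScaledSoftmaxAdapter, the fixed scale factor, and the CompressedVectorPtr<float> template argument (the angle-bracket contents are missing from the patch text) are assumptions.

#include "TFSoftmaxAdapter.hh"

namespace Lm {

// Hypothetical adapter, not part of this commit: behaves like
// TFPassthroughSoftmaxAdapter but scales the raw network output.
class ExampleScaledSoftmaxAdapter : public TFSoftmaxAdapter {
public:
    using Precursor = TFSoftmaxAdapter;

    ExampleScaledSoftmaxAdapter(Core::Configuration const& config)
            : Precursor(config) {}
    virtual ~ExampleScaledSoftmaxAdapter() = default;

    // This toy adapter needs no tensors from the session up front.
    virtual void init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) {
    }

    // Same contract as TFPassthroughSoftmaxAdapter::get_score, with a constant scale applied.
    virtual Score get_score(Lm::CompressedVectorPtr<float> const& nn_out, size_t output_idx) {
        return 2.0f * nn_out->get(output_idx);
    }
};

}  // namespace Lm

To make such an adapter selectable via the softmax-adapter "type" parameter, it would additionally have to be added to the SoftmaxAdapterType choice and the createSoftmaxAdapter factory shown earlier in this file.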
} // namespace Lm diff --git a/src/Lm/TFRecurrentLanguageModel.hh b/src/Lm/TFRecurrentLanguageModel.hh index 6f8fd069f..d7659d759 100644 --- a/src/Lm/TFRecurrentLanguageModel.hh +++ b/src/Lm/TFRecurrentLanguageModel.hh @@ -1,166 +1,46 @@ -/** Copyright 2020 RWTH Aachen University. All rights reserved. - * - * Licensed under the RWTH ASR License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ #ifndef _LM_TF_RECURRENT_LANGUAGE_MODEL_HH #define _LM_TF_RECURRENT_LANGUAGE_MODEL_HH -#include -#include -#include - +#include #include #include #include +#include #include -#include -#include "AbstractNNLanguageModel.hh" -#include "CompressedVector.hh" -#include "SearchSpaceAwareLanguageModel.hh" -#include "SoftmaxAdapter.hh" -#include "StateManager.hh" +#include "RecurrentLanguageModel.hh" +#include "TFSoftmaxAdapter.hh" namespace Lm { -class TFRecurrentLanguageModel : public AbstractNNLanguageModel, public SearchSpaceAwareLanguageModel { +class TFRecurrentLanguageModel : public RecurrentLanguageModel { public: - struct TimeStatistics { - std::chrono::duration total_duration; - std::chrono::duration early_request_duration; - std::chrono::duration request_duration; - std::chrono::duration prepare_duration; - std::chrono::duration merge_state_duration; - std::chrono::duration set_state_duration; - std::chrono::duration run_nn_output_duration; - std::chrono::duration set_nn_output_duration; - std::chrono::duration get_new_state_duration; - std::chrono::duration split_state_duration; - std::chrono::duration softmax_output_duration; - - TimeStatistics operator+(TimeStatistics const& other) const; - TimeStatistics& operator+=(TimeStatistics const& other); - - void write(Core::XmlChannel& channel) const; - void write(std::ostream& out) const; - }; - - typedef AbstractNNLanguageModel Precursor; - typedef moodycamel::BlockingReaderWriterQueue HistoryQueue; - - static const Core::ParameterBool paramTransformOuputLog; - static const Core::ParameterBool paramTransformOuputNegate; - static const Core::ParameterInt paramMinBatchSize; - static const Core::ParameterInt paramOptBatchSize; - static const Core::ParameterInt paramMaxBatchSize; - static const Core::ParameterInt paramHistoryPruningThreshold; - static const Core::ParameterInt paramPrunedHistoryLength; - static const Core::ParameterFloat paramBatchPruningThreshold; - static const Core::ParameterBool paramAllowReducedHistory; - static const Core::ParameterBool paramDumpInputs; - static const Core::ParameterString paramDumpInputsPrefix; - static const Core::ParameterBool paramDumpScores; - static const Core::ParameterString paramDumpScoresPrefix; - static const Core::ParameterBool paramLogMemory; - static const Core::ParameterBool paramFreeMemory; - static const Core::ParameterInt paramFreeMemoryDelay; - static const Core::ParameterBool paramAsync; - static const Core::ParameterBool paramSingleStepOnly; - static const Core::ParameterBool paramVerbose; + using Precursor = RecurrentLanguageModel; TFRecurrentLanguageModel(Core::Configuration const& c, Bliss::LexiconRef l); - virtual 
~TFRecurrentLanguageModel(); - - virtual History startHistory() const; - virtual History extendedHistory(History const& hist, Token w) const; - virtual History extendedHistory(History const& hist, Bliss::Token::Id w) const; - virtual History reducedHistory(History const& hist, u32 limit) const; - virtual Score score(History const& hist, Token w) const; - virtual bool scoreCached(History const& hist, Token w) const; - - virtual void startFrame(Search::TimeframeIndex time) const; - virtual void setInfo(History const& hist, SearchSpaceInformation const& info) const; + virtual ~TFRecurrentLanguageModel() {} protected: - virtual void load(); + virtual void setState(std::vector> const& inputs, std::vector const& targets) const; + virtual void extendInputs(std::vector>& inputs, Math::FastMatrix const& words, Math::FastVector const& word_lengths, std::vector const& state_lengths) const; + virtual void extendTargets(std::vector& targets) const; + virtual void getOutputs(std::vector>& inputs, std::vector& outputs, std::vector const& targets) const; + virtual std::vector fetchStates(std::vector& outputs) const; -private: - bool transform_output_log_; - bool transform_output_negate_; - std::function output_transform_function_; - size_t min_batch_size_; - size_t opt_batch_size_; - size_t max_batch_size_; - size_t history_pruning_threshold_; - size_t pruned_history_length_; - Score batch_pruning_threshold_; - bool allow_reduced_history_; - bool dump_inputs_; - std::string dump_inputs_prefix_; - bool dump_scores_; - std::string dump_scores_prefix_; - bool log_memory_; - bool free_memory_; - Search::TimeframeIndex free_memory_delay_; - bool async_; - bool single_step_only_; - bool verbose_; + virtual Score transformOutput(Lm::CompressedVectorPtr const& nn_output, size_t index) const; +private: mutable Tensorflow::Session session_; std::unique_ptr loader_; std::unique_ptr graph_; Tensorflow::TensorInputMap tensor_input_map_; Tensorflow::TensorOutputMap tensor_output_map_; - CompressedVectorFactoryPtr state_comp_vec_factory_; - CompressedVectorFactoryPtr nn_output_comp_vec_factory_; - - std::vector state_variables_; - std::unique_ptr state_manager_; - std::unique_ptr softmax_adapter_; + std::unique_ptr softmax_adapter_; std::vector initializer_tensor_names_; std::vector output_tensor_names_; std::vector read_vars_tensor_names_; - - History empty_history_; // a history used to provide the previous (all zero) state to the first real history (1 sentence-begin token) - - mutable Core::XmlChannel statistics_; - mutable Search::TimeframeIndex current_time_; - mutable std::vector run_time_; - mutable std::vector run_count_; - mutable double total_wait_time_; - mutable double total_start_frame_time_; - mutable double total_expand_hist_time_; - mutable TimeStatistics fwd_statistics_; - mutable size_t dump_inputs_counter_; - - // members for async forwarding - std::thread background_forwarder_thread_; - bool should_stop_; - - mutable std::atomic to_fwd_; - mutable std::promise to_fwd_finished_; - - mutable std::vector pending_; - mutable HistoryQueue fwd_queue_; - mutable HistoryQueue finished_queue_; - - History extendHistoryWithOutputIdx(History const& hist, size_t w) const; - - void background_forward() const; - template - void forward(Lm::History const* hist) const; }; } // namespace Lm diff --git a/src/Lm/SoftmaxAdapter.hh b/src/Lm/TFSoftmaxAdapter.hh similarity index 82% rename from src/Lm/SoftmaxAdapter.hh rename to src/Lm/TFSoftmaxAdapter.hh index c5e1adb45..989ee4b4c 100644 --- a/src/Lm/SoftmaxAdapter.hh +++ 
b/src/Lm/TFSoftmaxAdapter.hh @@ -25,12 +25,12 @@ namespace Lm { using Score = float; -class SoftmaxAdapter : public Core::Component { +class TFSoftmaxAdapter : public Core::Component { public: using Precursor = Core::Component; - SoftmaxAdapter(Core::Configuration const& config); - virtual ~SoftmaxAdapter() = default; + TFSoftmaxAdapter(Core::Configuration const& config); + virtual ~TFSoftmaxAdapter() = default; virtual void init(Tensorflow::Session& session, Tensorflow::TensorInputMap const& input_map, Tensorflow::TensorOutputMap const& output_map) = 0; virtual Score get_score(Lm::CompressedVectorPtr const& nn_out, size_t output_idx) = 0; @@ -41,11 +41,11 @@ private: // inline implementations -inline SoftmaxAdapter::SoftmaxAdapter(Core::Configuration const& config) +inline TFSoftmaxAdapter::TFSoftmaxAdapter(Core::Configuration const& config) : Precursor(config) { } -inline std::vector SoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { +inline std::vector TFSoftmaxAdapter::get_scores(Lm::CompressedVectorPtr const& nn_out, std::vector const& output_idxs) { std::vector scores; scores.reserve(output_idxs.size()); for (size_t output_idx : output_idxs) { diff --git a/src/Lm/TransformerStateManager.cc b/src/Lm/TFTransformerStateManager.cc similarity index 61% rename from src/Lm/TransformerStateManager.cc rename to src/Lm/TFTransformerStateManager.cc index 2bce463fd..cdd25e727 100644 --- a/src/Lm/TransformerStateManager.cc +++ b/src/Lm/TFTransformerStateManager.cc @@ -12,7 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "TransformerStateManager.hh" +#include "TFTransformerStateManager.hh" #include "FixedQuantizationCompressedVectorFactory.hh" @@ -72,28 +72,28 @@ void uncompress(Lm::CompressedVector const* vec, int8_t* dst, B const& b) namespace Lm { -/* ----------------------------------- TransformerStateManager ---------------------------------- */ +/* ----------------------------------- TFTransformerStateManager ---------------------------------- */ template -const Core::ParameterInt TransformerStateManager::paramMaxHistoryLength("max-history", - "maximum length of the history to feed to the transformer", - std::numeric_limits::max(), - 0); -template const Core::ParameterInt TransformerStateManager::paramMaxHistoryLength; -template const Core::ParameterInt TransformerStateManager::paramMaxHistoryLength; -template const Core::ParameterInt TransformerStateManager::paramMaxHistoryLength; +const Core::ParameterInt TFTransformerStateManager::paramMaxHistoryLength("max-history", + "maximum length of the history to feed to the transformer", + std::numeric_limits::max(), + 0); +template const Core::ParameterInt TFTransformerStateManager::paramMaxHistoryLength; +template const Core::ParameterInt TFTransformerStateManager::paramMaxHistoryLength; +template const Core::ParameterInt TFTransformerStateManager::paramMaxHistoryLength; template -const Core::ParameterBool TransformerStateManager::paramAlwaysIncludeFirstTokenState("always-include-first-token-state", - "wether to always include the state of the first token, even if history is restricted by max-history", - false); -template const Core::ParameterBool TransformerStateManager::paramAlwaysIncludeFirstTokenState; -template const Core::ParameterBool TransformerStateManager::paramAlwaysIncludeFirstTokenState; -template const Core::ParameterBool TransformerStateManager::paramAlwaysIncludeFirstTokenState; +const Core::ParameterBool 
TFTransformerStateManager::paramAlwaysIncludeFirstTokenState("always-include-first-token-state", + "wether to always include the state of the first token, even if history is restricted by max-history", + false); +template const Core::ParameterBool TFTransformerStateManager::paramAlwaysIncludeFirstTokenState; +template const Core::ParameterBool TFTransformerStateManager::paramAlwaysIncludeFirstTokenState; +template const Core::ParameterBool TFTransformerStateManager::paramAlwaysIncludeFirstTokenState; template -typename TransformerStateManager::HistoryState TransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory) { - TransformerStateManager::HistoryState result; +typename TFTransformerStateManager::HistoryState TFTransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory) { + TFTransformerStateManager::HistoryState result; result.reserve(vars.size()); std::vector vec(0, 0.0f); @@ -107,16 +107,16 @@ typename TransformerStateManager::HistoryState TransformerStateManager::in return result; } -template typename TransformerStateManager::HistoryState TransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); -template typename TransformerStateManager::HistoryState TransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); -template typename TransformerStateManager::HistoryState TransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); +template typename TFTransformerStateManager::HistoryState TFTransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); +template typename TFTransformerStateManager::HistoryState TFTransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); +template typename TFTransformerStateManager::HistoryState TFTransformerStateManager::initialState(StateVariables const& vars, CompressedVectorFactory const& vector_factory); template -void TransformerStateManager::mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets) { +void TFTransformerStateManager::mergeStates(StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + FeedDict& feed_dict, + TargetList& targets) { std::vector original_prefix_lengths(prefix_lengths); size_t max_prefix = 0ul; @@ -180,28 +180,28 @@ void TransformerStateManager::mergeStates(StateVariables const& } } -template void TransformerStateManager::mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets); -template void TransformerStateManager::mergeStates(StateVariables const& vars, +template void TFTransformerStateManager::mergeStates(StateVariables const& vars, std::vector& prefix_lengths, std::vector const& prefix_states, FeedDict& feed_dict, TargetList& targets); -template void TransformerStateManager::mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets); +template void TFTransformerStateManager::mergeStates(StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + FeedDict& feed_dict, + TargetList& targets); 
+template void TFTransformerStateManager::mergeStates(StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + FeedDict& feed_dict, + TargetList& targets); template -std::vector::HistoryState> - TransformerStateManager::splitStates(StateVariables const& vars, - std::vector& suffix_lengths, - std::vector const& state_tensors, - CompressedVectorFactory const& vector_factory) { +std::vector::HistoryState> + TFTransformerStateManager::splitStates(StateVariables const& vars, + std::vector& suffix_lengths, + std::vector const& state_tensors, + CompressedVectorFactory const& vector_factory) { require_eq(vars.size(), state_tensors.size()); size_t max_suffix = *std::max_element(suffix_lengths.begin(), suffix_lengths.end()); @@ -258,78 +258,78 @@ std::vector::HistoryState> return result; } -template std::vector::HistoryState> - TransformerStateManager::splitStates(StateVariables const& vars, - std::vector& suffix_lengths, - std::vector const& state_tensors, - CompressedVectorFactory const& vector_factory); -template std::vector::HistoryState> - TransformerStateManager::splitStates(StateVariables const& vars, +template std::vector::HistoryState> + TFTransformerStateManager::splitStates(StateVariables const& vars, std::vector& suffix_lengths, std::vector const& state_tensors, CompressedVectorFactory const& vector_factory); -template std::vector::HistoryState> - TransformerStateManager::splitStates(StateVariables const& vars, - std::vector& suffix_lengths, - std::vector const& state_tensors, - CompressedVectorFactory const& vector_factory); - -/* ----------------------------------- TransformerStateManagerWithCommonPrefix ---------------------------------- */ +template std::vector::HistoryState> + TFTransformerStateManager::splitStates(StateVariables const& vars, + std::vector& suffix_lengths, + std::vector const& state_tensors, + CompressedVectorFactory const& vector_factory); +template std::vector::HistoryState> + TFTransformerStateManager::splitStates(StateVariables const& vars, + std::vector& suffix_lengths, + std::vector const& state_tensors, + CompressedVectorFactory const& vector_factory); + +/* ----------------------------------- TFTransformerStateManagerWithCommonPrefix ---------------------------------- */ template -const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramVarName("var-name", "the name of the original state variable", ""); -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramVarName; -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramVarName; -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramVarName; +const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramVarName("var-name", "the name of the original state variable", ""); +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramVarName; +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramVarName; +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramVarName; template -const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue("common-prefix-initial-value", "the name the initial-value of the corresponding common-prefix variable", ""); -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue; -template const Core::ParameterString 
TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue; -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue; +const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue("common-prefix-initial-value", "the name of the initial-value of the corresponding common-prefix variable", ""); +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue; +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue; +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitialValue; template -const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer("common-prefix-initializer", "the name of the initializer of the corresponding common-prefix variable", ""); -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer; -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer; -template const Core::ParameterString TransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer; +const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer("common-prefix-initializer", "the name of the initializer of the corresponding common-prefix variable", ""); +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer; +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer; +template const Core::ParameterString TFTransformerStateManagerWithCommonPrefix::paramCommonPrefixInitializer; template -const Core::ParameterBool TransformerStateManagerWithCommonPrefix::paramCachePrefix("cache-prefix", "wether to reuse the prefix if it's the same", false); -template const Core::ParameterBool TransformerStateManagerWithCommonPrefix::paramCachePrefix; -template const Core::ParameterBool TransformerStateManagerWithCommonPrefix::paramCachePrefix; -template const Core::ParameterBool TransformerStateManagerWithCommonPrefix::paramCachePrefix; +const Core::ParameterBool TFTransformerStateManagerWithCommonPrefix::paramCachePrefix("cache-prefix", "whether to reuse the prefix if it's the same", false); +template const Core::ParameterBool TFTransformerStateManagerWithCommonPrefix::paramCachePrefix; +template const Core::ParameterBool TFTransformerStateManagerWithCommonPrefix::paramCachePrefix; +template const Core::ParameterBool TFTransformerStateManagerWithCommonPrefix::paramCachePrefix; template -const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinBatchSize("min-batch-size", - "for batches smaller than the given size we set the common-prefix length to 0", - 2, 0); -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinBatchSize; -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinBatchSize; -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinBatchSize; +const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinBatchSize("min-batch-size", + "for batches smaller than the given size we set the common-prefix length to 0", + 2, 0); +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinBatchSize; +template const Core::ParameterInt 
TFTransformerStateManagerWithCommonPrefix::paramMinBatchSize; +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinBatchSize; template -const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength("min-common-prefix-length", - "if the common-prefix length is smaller than this value, set it to 0", - 1, 0); -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength; -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength; -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength; +const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength("min-common-prefix-length", + "if the common-prefix length is smaller than this value, set it to 0", + 1, 0); +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength; +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength; +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMinCommonPrefixLength; template -const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength("max-common-prefix-length", - "Truncate the common prefix to this length. Observes always-include-first-token-state.", - std::numeric_limits::max(), 0); -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength; -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength; -template const Core::ParameterInt TransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength; +const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength("max-common-prefix-length", + "Truncate the common prefix to this length. 
Observes always-include-first-token-state.", + std::numeric_limits::max(), 0); +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength; +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength; +template const Core::ParameterInt TFTransformerStateManagerWithCommonPrefix::paramMaxCommonPrefixLength; template -void TransformerStateManagerWithCommonPrefix::mergeStates(typename Precursor::StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - typename Precursor::FeedDict& feed_dict, - typename Precursor::TargetList& targets) { +void TFTransformerStateManagerWithCommonPrefix::mergeStates(typename Precursor::StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + typename Precursor::FeedDict& feed_dict, + typename Precursor::TargetList& targets) { std::vector original_prefix_lengths(prefix_lengths); std::vector batch_offsets; batch_offsets.reserve(prefix_lengths.size() + 1ul); @@ -464,20 +464,20 @@ void TransformerStateManagerWithCommonPrefix::mergeStates(typename Precursor: } } -template void TransformerStateManagerWithCommonPrefix::mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets); -template void TransformerStateManagerWithCommonPrefix::mergeStates(StateVariables const& vars, +template void TFTransformerStateManagerWithCommonPrefix::mergeStates(StateVariables const& vars, std::vector& prefix_lengths, std::vector const& prefix_states, FeedDict& feed_dict, TargetList& targets); -template void TransformerStateManagerWithCommonPrefix::mergeStates(StateVariables const& vars, - std::vector& prefix_lengths, - std::vector const& prefix_states, - FeedDict& feed_dict, - TargetList& targets); +template void TFTransformerStateManagerWithCommonPrefix::mergeStates(StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + FeedDict& feed_dict, + TargetList& targets); +template void TFTransformerStateManagerWithCommonPrefix::mergeStates(StateVariables const& vars, + std::vector& prefix_lengths, + std::vector const& prefix_states, + FeedDict& feed_dict, + TargetList& targets); } // namespace Lm diff --git a/src/Lm/TransformerStateManager.hh b/src/Lm/TFTransformerStateManager.hh similarity index 82% rename from src/Lm/TransformerStateManager.hh rename to src/Lm/TFTransformerStateManager.hh index cfafbc806..5fac61dd6 100644 --- a/src/Lm/TransformerStateManager.hh +++ b/src/Lm/TFTransformerStateManager.hh @@ -15,22 +15,25 @@ #ifndef _LM_TRANSFORMER_STATE_MANAGER_HH #define _LM_TRANSFORMER_STATE_MANAGER_HH -#include "StateManager.hh" +#include "AbstractStateManager.hh" + +#include +#include #include namespace Lm { template -class TransformerStateManager : public StateManager { +class TFTransformerStateManager : public AbstractStateManager { public: - using Precursor = StateManager; + using Precursor = AbstractStateManager; static const Core::ParameterInt paramMaxHistoryLength; static const Core::ParameterBool paramAlwaysIncludeFirstTokenState; - TransformerStateManager(Core::Configuration const& config); - virtual ~TransformerStateManager() = default; + TFTransformerStateManager(Core::Configuration const& config); + virtual ~TFTransformerStateManager() = default; virtual bool requiresAllParentStates() const; @@ -51,9 +54,9 @@ protected: }; template -class TransformerStateManagerWithCommonPrefix : 
public TransformerStateManager { +class TFTransformerStateManagerWithCommonPrefix : public TFTransformerStateManager { public: - using Precursor = TransformerStateManager; + using Precursor = TFTransformerStateManager; static const Core::ParameterString paramVarName; static const Core::ParameterString paramCommonPrefixInitialValue; @@ -63,8 +66,8 @@ public: static const Core::ParameterInt paramMinCommonPrefixLength; static const Core::ParameterInt paramMaxCommonPrefixLength; - TransformerStateManagerWithCommonPrefix(Core::Configuration const& config); - virtual ~TransformerStateManagerWithCommonPrefix() = default; + TFTransformerStateManagerWithCommonPrefix(Core::Configuration const& config); + virtual ~TFTransformerStateManagerWithCommonPrefix() = default; virtual void mergeStates(typename Precursor::StateVariables const& vars, std::vector& prefix_lengths, @@ -86,19 +89,19 @@ protected: // inline implementations template -inline TransformerStateManager::TransformerStateManager(Core::Configuration const& config) +inline TFTransformerStateManager::TFTransformerStateManager(Core::Configuration const& config) : Precursor(config), maxHistory_(paramMaxHistoryLength(config)), alwaysIncludeFirstTokenState_(paramAlwaysIncludeFirstTokenState(config)) { } template -inline bool TransformerStateManager::requiresAllParentStates() const { +inline bool TFTransformerStateManager::requiresAllParentStates() const { return true; } template -inline TransformerStateManagerWithCommonPrefix::TransformerStateManagerWithCommonPrefix(Core::Configuration const& config) +inline TFTransformerStateManagerWithCommonPrefix::TFTransformerStateManagerWithCommonPrefix(Core::Configuration const& config) : Precursor(config), cachePrefix_(paramCachePrefix(config)), minBatchSize_(paramMinBatchSize(config)), diff --git a/src/Lm/Zerogram.cc b/src/Lm/Zerogram.cc index b8bb67118..441e71283 100644 --- a/src/Lm/Zerogram.cc +++ b/src/Lm/Zerogram.cc @@ -19,7 +19,8 @@ using namespace Lm; Zerogram::Zerogram(const Core::Configuration& c, Bliss::LexiconRef l) - : Core::Component(c), LanguageModel(c, l) { + : Core::Component(c), + LanguageModel(c, l) { historyManager_ = this; log("Zerogram LM probability is 1/%d", lexicon()->nSyntacticTokens()); score_ = ::log(f64(lexicon()->nSyntacticTokens())); diff --git a/src/Lm/Zerogram.hh b/src/Lm/Zerogram.hh index 2a62a2578..c7ee95fe8 100644 --- a/src/Lm/Zerogram.hh +++ b/src/Lm/Zerogram.hh @@ -30,13 +30,17 @@ private: public: Zerogram(const Core::Configuration& c, Bliss::LexiconRef); + virtual Fsa::ConstAutomatonRef getFsa() const; - virtual History startHistory() const { + + virtual History startHistory() const { return history(0); } + virtual History extendedHistory(const History& h, Token) const { return h; } + virtual Score score(const History&, Token) const { return score_; } diff --git a/src/Lm/patch-philips-lm b/src/Lm/patch-philips-lm deleted file mode 100755 index e020ef92f..000000000 --- a/src/Lm/patch-philips-lm +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -# $Id$ - -# This utility program patches an original Philips language model file -# to use it with Sprint/Lm -# -# The following adaptations are performed: -# - Certain keywords in the class LM are changed to their non-class -# equivalents. -# - German TeX-style umlauts are converted to UTF-8. -# - Should any non-ASCII characters occur, they are interpreted as -# ISO-8859-1 and consequently converted to UTF-8. - -if [ $# -eq 2 ]; then - orig=$1 - out=$2 -else - echo "Usage: $0 " - # for i in inp: - # txt+= "" % i - # txt += "
%s
" - # inp = txt - #else: - # inp = "" - - - label = "%s" % (inp, name) - for i in range(len(atts)): - key = atts.item(i).name - value = att(n, key) -# value = clean(value) - if key=="name": continue - if key=="filter": - if value[-5:] == ".flow": - value="%s" % value - elif value[0]!="$": - value="%s" % value - if value[0]=="$": - value="%s" % value - - label += "" % (key, value) - - label += "
%s
%s%s
" - name = clean(name) - - print >> out, """%s [shape=plaintext + if len(argv) < 2: + usage(argv[0]) + + dom = parse(argv[1]) + nodes = dom.getElementsByTagName("node") + links = dom.getElementsByTagName("link") + inputs = dom.getElementsByTagName("in") + outputs = dom.getElementsByTagName("out") + params = dom.getElementsByTagName("param") + network = dom.getElementsByTagName("network")[0] + + out, outname = tempfile.mkstemp() + out = open(outname, "w") + print("digraph flow {", file=out) + + att = lambda n, a: n.getAttribute(a) + + txt = "" + netname = att(network, "name") + if not netname: + netname = "network" + + for i in inputs + outputs: + i = att(i, "name") + print( + """%s [shape=plaintext + label = "%s:%s"];""" + % (i, netname, i), + file=out, + ) + + for n in nodes: + name = att(n, "name") + atts = n.attributes + inp = "" + + label = ( + "%s" + % (inp, name) + ) + for i in range(len(atts)): + key = atts.item(i).name + value = att(n, key) + # value = clean(value) + if key == "name": + continue + if key == "filter": + if value[-5:] == ".flow": + value = "%s" % value + elif value[0] != "$": + value = "%s" % value + if value[0] == "$": + value = "%s" % value + + label += "" % ( + key, + value, + ) + + label += "
%s
%s%s
" + name = clean(name) + + print( + """%s [shape=plaintext label = <%s> - ];""" % (name, label) + ];""" + % (name, label), + file=out, + ) - for l in links: + for l in links: - fr = att(l, "from") - to = att(l, "to") + fr = att(l, "from") + to = att(l, "to") - fr = clean(fr) - to = clean(to) + fr = clean(fr) + to = clean(to) - if fr.startswith(netname+":"): fr=fr.split(":")[1] - if to.startswith(netname+":"): to=to.split(":")[1] + if fr.startswith(netname + ":"): + fr = fr.split(":")[1] + if to.startswith(netname + ":"): + to = to.split(":")[1] - label = "" - if ":" in fr or ":" in to: - port1 = "" - port2 = "" - if ":" in fr: port1 = fr.split(":")[1] - if ":" in to: port2 = to.split(":")[1] - label = ' [label="%s->%s", fontname="svg"]' % (port1, port2) + label = "" + if ":" in fr or ":" in to: + port1 = "" + port2 = "" + if ":" in fr: + port1 = fr.split(":")[1] + if ":" in to: + port2 = to.split(":")[1] + label = ' [label="%s->%s", fontname="svg"]' % (port1, port2) - print >> out, "%s -> %s %s;" % (fr, to, label) + print("%s -> %s %s;" % (fr, to, label), file=out) - print >> out, "}" - #print "dot -Tpng < %s > %s.png; eog %s.png " % (outname, outname, outname) - out.close() + print("}", file=out) + # print "dot -Tpng < %s > %s.png; eog %s.png " % (outname, outname, outname) + out.close() - if len(argv) == 3: - pic = argv[2] - else: - pic = "%s.png" % outname + if len(argv) == 3: + pic = argv[2] + else: + pic = "%s.png" % outname + + os.system("dot -Tpng < %s > %s" % (outname, pic)) + print("Plot written to %s" % pic) - os.system("dot -Tpng < %s > %s" % (outname, pic)) - print "Plot written to %s" % pic def usage(prog): - print "USAGE: %s file.flow [plot.png]" % prog - sys.exit(-1) + print("USAGE: %s file.flow [plot.png]" % prog) + sys.exit(-1) + if __name__ == "__main__": - main(sys.argv) + main(sys.argv) diff --git a/src/Tools/LatticeProcessor/LatticeProcessor.cc b/src/Tools/LatticeProcessor/LatticeProcessor.cc index 33f1f653c..241f7b6ca 100644 --- a/src/Tools/LatticeProcessor/LatticeProcessor.cc +++ b/src/Tools/LatticeProcessor/LatticeProcessor.cc @@ -361,13 +361,14 @@ std::string LatticeProcessor::getApplicationDescription() const { } std::string LatticeProcessor::getParameterDescription() const { - std::string tmp = "\n" - "supported actions:\n" - "\n" - "\t\"not-given\": do nothing\n" - "\t[*.selection]\n" - "\t# nothing to configurate\n" - "\t\n"; + std::string tmp = "\n" + "supported actions:\n" + "\n" + "\t\"not-given\": do nothing\n" + "\t[*.selection]\n" + "\t# nothing to configurate\n" + "\t\n"; + std::vector actions = processorFactory_.identifiers(); for (std::vector::const_iterator a = actions.begin(); a != actions.end(); ++a) { std::string name = choiceAction[*a]; @@ -468,8 +469,7 @@ APPLICATION(LatticeProcessor) /** * Parse the actions and selections statements in the configuration */ -void LatticeProcessor::parseActionsSelections( - std::vector& actions, std::vector& selections) { +void LatticeProcessor::parseActionsSelections(std::vector& actions, std::vector& selections) { require(actions.empty() && selections.empty()); std::vector _actions = paramActions(config); selections = paramSelections(config); diff --git a/src/Tools/LatticeProcessor/LatticeProcessor.hh b/src/Tools/LatticeProcessor/LatticeProcessor.hh index fd911dc29..fe90ef8b8 100644 --- a/src/Tools/LatticeProcessor/LatticeProcessor.hh +++ b/src/Tools/LatticeProcessor/LatticeProcessor.hh @@ -115,7 +115,9 @@ private: static const Core::Choice choiceAction; static const Core::ParameterStringVector paramActions; 
static const Core::ParameterStringVector paramSelections; - enum CorpusType { bliss }; + enum CorpusType { + bliss + }; static const Core::Choice choiceCorpusType; static const Core::ParameterChoice paramCorpusType; enum ApplicationType { diff --git a/src/Tools/LatticeProcessor/Makefile b/src/Tools/LatticeProcessor/Makefile index a360a95b1..3fa2009e3 100644 --- a/src/Tools/LatticeProcessor/Makefile +++ b/src/Tools/LatticeProcessor/Makefile @@ -25,8 +25,8 @@ LATTICE_PROCESSOR_O = $(OBJDIR)/LatticeProcessor.o \ ../../Core/libSprintCore.$(a) \ ../../Fsa/libSprintFsa.$(a) \ ../../Flf/FlfCore/libSprintFlfCore.$(a) - -LATTICE_PROCESSOR_O += $(subst src,../..,$(LIBS_SEARCH)) + +LATTICE_PROCESSOR_O += $(subst src,../..,$(LIBS_SEARCH)) ifdef MODULE_CART LATTICE_PROCESSOR_O += ../../Cart/libSprintCart.$(a) @@ -48,6 +48,9 @@ LATTICE_PROCESSOR_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) LDFLAGS += $(TF_LDFLAGS) endif +ifdef MODULE_LM_ONNX +LATTICE_PROCESSOR_O += ../../Onnx/libSprintOnnx.$(a) +endif # ----------------------------------------------------------------------------- build: $(TARGETS) diff --git a/src/Tools/LibRASR/LibRASR.cc b/src/Tools/LibRASR/LibRASR.cc index ee676611b..76c6694da 100644 --- a/src/Tools/LibRASR/LibRASR.cc +++ b/src/Tools/LibRASR/LibRASR.cc @@ -8,14 +8,14 @@ #include #include #include -#include -#include #ifdef MODULE_NN #include #endif #ifdef MODULE_ONNX #include #endif +#include +#include #ifdef MODULE_TENSORFLOW #include #endif diff --git a/src/Tools/Lm/LmUtilityTool.cc b/src/Tools/Lm/LmUtilityTool.cc index e8d320ff8..c1b7ce6e7 100644 --- a/src/Tools/Lm/LmUtilityTool.cc +++ b/src/Tools/Lm/LmUtilityTool.cc @@ -136,15 +136,37 @@ void LmUtilityTool::computePerplexityFromTextFile() { Core::TextInputStream tis(new Core::CompressedInputStream(paramFile(config))); Core::TextOutputStream out; + log("reading text from '%s'", paramFile(config).c_str()); tis.setEncoding(paramEncoding(config)); + out.setEncoding(paramEncoding(config)); std::string out_file = paramScoreFile(config); if (not out_file.empty()) { out.open(out_file); + log("saving scores to '%s'", out_file.c_str()); } std::vector requests; size_t num_tokens = 0; + size_t num_lines = 0; + size_t num_unks = 0; + size_t num_eos = 0; Lm::Score corpus_score = 0.0; + Lm::Score eos_scores = 0.0; + Lm::Score unks_scores = 0.0; + + Bliss::Lemma const* eos_lemma = lexicon->specialLemma("sentence-boundary"); + if (eos_lemma == nullptr) { + eos_lemma = lexicon->specialLemma("sentence-end"); + } + require_ne(eos_lemma, nullptr); + + Bliss::Lemma const* sos_lemma = lexicon->specialLemma("sentence-begin"); + if (sos_lemma == nullptr) { + warning("sentence-begin not found, using unigram probability instead\n"); + } + + Bliss::Lemma const* unk_lemma = lexicon->specialLemma("unknown"); + require_ne(unk_lemma, nullptr); do { std::string line; @@ -165,17 +187,25 @@ void LmUtilityTool::computePerplexityFromTextFile() { h = lm->extendedHistory(h, t); } } - Bliss::Lemma const* lemma = lexicon->specialLemma("sentence-end"); - auto const tokens = lemma->syntacticTokenSequence(); + auto const tokens = eos_lemma->syntacticTokenSequence(); for (auto const& t : tokens) { - requests.emplace_back(LMRequest({"\\n", lemma, t, h, 0.0f})); + requests.emplace_back(LMRequest({"\\n", eos_lemma, t, h, 0.0f})); h = lm->extendedHistory(h, t); } + ++num_lines; } if (not tis.good() or requests.size() >= batch_size) { computeAllScores(requests, lm, renormalize); for (auto const& r : requests) { + if (r.lemma == eos_lemma) { + 
eos_scores += r.score; + num_eos += 1ul; + } + if (r.lemma == unk_lemma) { + unks_scores += r.score; + num_unks += 1ul; + } corpus_score += r.score; num_tokens += 1ul; if (out.good()) { @@ -186,9 +216,18 @@ void LmUtilityTool::computePerplexityFromTextFile() { } } while (tis.good()); - Lm::Score ppl = std::exp(corpus_score / num_tokens); + Lm::Score ppl = std::exp(corpus_score / num_tokens); + Lm::Score ppl_wo_eos = std::exp((corpus_score - eos_scores) / (num_tokens - num_eos)); + Lm::Score ppl_wo_unks = std::exp((corpus_score - unks_scores) / (num_tokens - num_unks)); + Lm::Score ppl_wo_eos_wo_unks = std::exp((corpus_score - unks_scores - eos_scores) / (num_tokens - num_unks - num_eos)); log() << Core::XmlOpen("corpus-score") << corpus_score << Core::XmlClose("corpus-score") << Core::XmlOpen("num-tokens") << num_tokens << Core::XmlClose("num-tokens") - << Core::XmlOpen("perplexity") << ppl << Core::XmlClose("perplexity"); + << Core::XmlOpen("num-unks") << num_unks << Core::XmlClose("num-unks") + << Core::XmlOpen("unk-ratio") << static_cast(num_unks) / static_cast(num_tokens) << Core::XmlClose("unk-ratio") + << Core::XmlOpen("num-lines") << num_lines << Core::XmlClose("num-lines") + << Core::XmlOpen("perplexity") << ppl << Core::XmlClose("perplexity") + << Core::XmlOpen("perplexity-without-eos") << ppl_wo_eos << Core::XmlClose("perplexity-without-eos") + << Core::XmlOpen("perplexity-without-unknowns") << ppl_wo_unks << Core::XmlClose("perplexity-without-unknowns") + << Core::XmlOpen("perplexity-without-eos-without-unknowns") << ppl_wo_eos_wo_unks << Core::XmlClose("perplexity-without-eos-without-unknowns"); } diff --git a/src/Tools/Lm/Makefile b/src/Tools/Lm/Makefile index f4982659a..9a52b6c87 100644 --- a/src/Tools/Lm/Makefile +++ b/src/Tools/Lm/Makefile @@ -52,6 +52,9 @@ LM_UTIL_TOOL_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) LDFLAGS += $(TF_LDFLAGS) endif +ifdef MODULE_LM_ONNX +LM_UTIL_TOOL_O += ../../Onnx/libSprintOnnx.$(a) +endif # ----------------------------------------------------------------------------- diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index f2df495ce..6a9a227bc 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -48,6 +48,9 @@ NN_TRAINER_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) LDFLAGS += $(TF_LDFLAGS) endif +ifdef MODULE_LM_ONNX +NN_TRAINER_O += ../../Onnx/libSprintOnnx.$(a) +endif # ----------------------------------------------------------------------------- diff --git a/src/Tools/SpeechRecognizer/Makefile b/src/Tools/SpeechRecognizer/Makefile index c8b79024c..0c7607c62 100644 --- a/src/Tools/SpeechRecognizer/Makefile +++ b/src/Tools/SpeechRecognizer/Makefile @@ -22,6 +22,8 @@ COMMON_O = ../../Speech/libSprintSpeech.$(a) \ ../../Math/libSprintMath.$(a) \ ../../Math/Lapack/libSprintMathLapack.$(a) \ ../../Core/libSprintCore.$(a) \ + ../../Flf/libSprintFlf.$(a) \ + ../../Flf/FlfCore/libSprintFlfCore.$(a) \ ../../Fsa/libSprintFsa.$(a) COMMON_O += $(subst src,../..,$(LIBS_SEARCH)) @@ -29,8 +31,8 @@ COMMON_O += $(subst src,../..,$(LIBS_SEARCH)) ifdef MODULE_CART COMMON_O += ../../Cart/libSprintCart.$(a) endif -ifdef MODULE_FLF_CORE -COMMON_O += ../../Flf/FlfCore/libSprintFlfCore.$(a) +ifdef MODULE_FLF_EXT +COMMON_O += ../../Flf/FlfExt/libSprintFlfExt.$(a) endif ifdef MODULE_OPENFST ifeq ($(OS),darwin) @@ -56,7 +58,7 @@ endif ifdef MODULE_TENSORFLOW COMMON_O += ../../Tensorflow/libSprintTensorflow.$(a) CXXFLAGS += $(TF_CXXFLAGS) -LDFLAGS += 
$(TF_LDFLAGS) +LDFLAGS := $(TF_LDFLAGS) $(LDFLAGS) endif # -----------------------------------------------------------------------------
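
Note on the perplexity breakdown added to LmUtilityTool::computePerplexityFromTextFile(): the patch reports perplexity as exp(corpus_score / num_tokens) and then repeats the same computation with the sentence-end and unknown contributions subtracted from both the score sum and the token count. The following stand-alone sketch shows only that arithmetic; the struct and function names here are invented for illustration and are not part of the tool, and it assumes (as exp(corpus_score / num_tokens) implies) that each score is a negative natural-log probability.

#include <cmath>
#include <cstddef>
#include <vector>

// One scored token: score is a negative natural-log probability.
struct ScoredToken {
    double score;
    bool   isEos;  // token produced from the sentence-end / sentence-boundary lemma
    bool   isUnk;  // token produced from the unknown lemma
};

struct PerplexityReport {
    double ppl;
    double pplWithoutEos;
    double pplWithoutUnk;
    double pplWithoutEosAndUnk;
};

// Accumulate the totals once, then exponentiate the average score over the
// respective token subsets (caller must supply at least one token of each
// remaining subset to avoid dividing by zero).
PerplexityReport computePerplexities(std::vector<ScoredToken> const& tokens) {
    double      total = 0.0, eos = 0.0, unk = 0.0;
    std::size_t n = tokens.size(), nEos = 0, nUnk = 0;
    for (auto const& t : tokens) {
        total += t.score;
        if (t.isEos) { eos += t.score; ++nEos; }
        if (t.isUnk) { unk += t.score; ++nUnk; }
    }
    auto ppl = [](double score, std::size_t count) {
        return std::exp(score / static_cast<double>(count));
    };
    return {ppl(total, n),
            ppl(total - eos, n - nEos),
            ppl(total - unk, n - nUnk),
            ppl(total - eos - unk, n - nEos - nUnk)};
}

In the tool itself the two subsets are identified by comparing each request's lemma against the lexicon's special "sentence-end"/"sentence-boundary" and "unknown" lemmata, so the sets are disjoint and the combined variant never subtracts a score twice.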