diff --git a/src/Search/Module.cc b/src/Search/Module.cc
index 31e06e4c6..7fc2a47bf 100644
--- a/src/Search/Module.cc
+++ b/src/Search/Module.cc
@@ -47,6 +47,7 @@ const Core::Choice choiceTreeBuilderType(
         "minimized-hmm", static_cast(TreeBuilderType::minimizedHmm),
         "ctc", static_cast(TreeBuilderType::ctc),
         "rna", static_cast(TreeBuilderType::rna),
+        "aed", static_cast(TreeBuilderType::aed),
         Core::Choice::endMark());
 
 const Core::ParameterChoice paramTreeBuilderType(
@@ -70,6 +71,9 @@ std::unique_ptr Module_::createTreeBuilder(Core::Configurat
         case Search::TreeBuilderType::rna: {
             return std::unique_ptr(new RnaTreeBuilder(config, lexicon, acousticModel, network, initialize));
         } break;
+        case Search::TreeBuilderType::aed: {
+            return std::unique_ptr(new AedTreeBuilder(config, lexicon, acousticModel, network, initialize));
+        } break;
         default: defect();
     }
 }
diff --git a/src/Search/Module.hh b/src/Search/Module.hh
index f6e4211e2..d0a35bff2 100644
--- a/src/Search/Module.hh
+++ b/src/Search/Module.hh
@@ -32,6 +32,7 @@ enum TreeBuilderType {
     minimizedHmm,
     ctc,
     rna,
+    aed,
 };
 
 enum SearchType {
diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc
index 62175bccf..b33650283 100644
--- a/src/Search/TreeBuilder.cc
+++ b/src/Search/TreeBuilder.cc
@@ -1203,6 +1203,67 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success
     }
 }
 
+// -------------------- SharedBaseClassTreeBuilder --------------------
+
+SharedBaseClassTreeBuilder::SharedBaseClassTreeBuilder(Core::Configuration config,
+                                                       const Bliss::Lexicon& lexicon,
+                                                       const Am::AcousticModel& acousticModel,
+                                                       Search::PersistentStateTree& network)
+        : AbstractTreeBuilder(config, lexicon, acousticModel, network) {}
+
+StateId SharedBaseClassTreeBuilder::createRoot() {
+    return createState(StateTree::StateDesc(Search::StateTree::invalidAcousticModel, Am::TransitionModel::entryM1));
+}
+
+StateId SharedBaseClassTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc) {
+    // Check if the successor already exists
+    for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
+        if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) {
+            return *target;
+        }
+    }
+
+    // No matching successor found, extend
+    StateId ret = createState(desc);
+    network_.structure.addTargetToNode(predecessor, ret);
+    return ret;
+}
+
+void SharedBaseClassTreeBuilder::addTransition(StateId predecessor, StateId successor) {
+    // An existing transition is recognized by comparing the successors' StateDesc with the one of `successor` (not by StateId)
+    auto const& successorStateDesc = network_.structure.state(successor).stateDesc;
+
+    for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
+        if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) {
+            // The node is already a successor of the predecessor, so the transition already exists
+            return;
+        }
+    }
+
+    // The transition does not exist yet, add it
+    network_.structure.addTargetToNode(predecessor, successor);
+}
+
+u32 SharedBaseClassTreeBuilder::addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron) {
+    PersistentStateTree::Exit exit;
+    exit.transitState  = transitState;
+    exit.pronunciation = pron;
+
+    u32 exitIndex = createExit(exit);
+
+    // Check if the exit is already a successor
+    // This should only happen if the same lemma is contained multiple times in the lexicon
+    for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(state); target; ++target) {
+        if (target.isLabel() && target.label() == exitIndex) {
+            return exitIndex;
+        }
+    }
+
+    // The exit is not part of the successors yet, add it
+    network_.structure.addOutputToNode(state, ID_FROM_LABEL(exitIndex));
+    return exitIndex;
+}
+
 // -------------------- CtcTreeBuilder --------------------
 
 const Core::ParameterBool CtcTreeBuilder::paramLabelLoop(
@@ -1221,7 +1282,7 @@ const Core::ParameterBool CtcTreeBuilder::paramForceBlank(
         true);
 
 CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
-        : AbstractTreeBuilder(config, lexicon, acousticModel, network),
+        : SharedBaseClassTreeBuilder(config, lexicon, acousticModel, network),
           labelLoop_(paramLabelLoop(config)),
           blankLoop_(paramBlankLoop(config)),
           forceBlank_(paramForceBlank(config)) {
@@ -1282,59 +1343,6 @@ void CtcTreeBuilder::build() {
     }
 }
 
-StateId CtcTreeBuilder::createRoot() {
-    return createState(StateTree::StateDesc(Search::StateTree::invalidAcousticModel, Am::TransitionModel::entryM1));
-}
-
-u32 CtcTreeBuilder::addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron) {
-    PersistentStateTree::Exit exit;
-    exit.transitState  = transitState;
-    exit.pronunciation = pron;
-
-    u32 exitIndex = createExit(exit);
-
-    // Check if the exit is already a successor
-    // This should only happen if the same lemma is contained multiple times in the lexicon
-    for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(state); target; ++target) {
-        if (target.isLabel() && target.label() == exitIndex) {
-            return exitIndex;
-        }
-    }
-
-    // The exit is not part of the successors yet, add it
-    network_.structure.addOutputToNode(state, ID_FROM_LABEL(exitIndex));
-    return exitIndex;
-}
-
-StateId CtcTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc) {
-    // Check if the successor already exists
-    for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
-        if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) {
-            return *target;
-        }
-    }
-
-    // No matching successor found, extend
-    StateId ret = createState(desc);
-    network_.structure.addTargetToNode(predecessor, ret);
-    return ret;
-}
-
-void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) {
-    auto const& predecessorStateDesc = network_.structure.state(predecessor).stateDesc;
-    auto const& successorStateDesc   = network_.structure.state(successor).stateDesc;
-
-    for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
-        if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) {
-            // The node is already a successor of the predecessor, so the transition already exists
-            return;
-        }
-    }
-
-    // The transition does not exists yet, add it
-    network_.structure.addTargetToNode(predecessor, successor);
-}
-
 StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronunciation const* pron) {
     require(pron != nullptr);
 
@@ -1453,3 +1461,119 @@ RnaTreeBuilder::RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon&
     this->labelLoop_  = paramLabelLoop(config);
     this->forceBlank_ = paramForceBlank(config);
 }
+
+// -------------------- AedTreeBuilder --------------------
+
+AedTreeBuilder::AedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
+        : SharedBaseClassTreeBuilder(config, lexicon, acousticModel, network) {
+    auto iters = lexicon.phonemeInventory()->phonemes();
+    for (auto it = iters.first; it != iters.second; ++it) {
+        require(not(*it)->isContextDependent());  // Context dependent labels are not supported
+    }
+
+    if (initialize) {
+        verify(!network_.rootState);
+        network_.ciRootState = network_.rootState = createRoot();
+
+        // Create a special root for the word-boundary token if it exists in the lexicon
+        if (lexicon.specialLemma("word-boundary") != nullptr) {
+            wordBoundaryRoot_ = createRoot();
+            network_.otherRootStates.insert(wordBoundaryRoot_);
+        }
+    }
+}
+
+std::unique_ptr AedTreeBuilder::newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) {
+    return std::unique_ptr(new AedTreeBuilder(config, lexicon, acousticModel, network, initialize));  // forward `initialize` instead of silently using the default
+}
+
+void AedTreeBuilder::build() {
+    auto wordBoundaryLemma = lexicon_.specialLemma("word-boundary");
+    if (wordBoundaryLemma != nullptr) {
+        addWordBoundaryStates();
+    }
+
+    auto sentenceEndLemma = lexicon_.specialLemma("sentence-end");
+    if (!sentenceEndLemma) {
+        sentenceEndLemma = lexicon_.specialLemma("sentence-boundary");
+    }
+    require(sentenceEndLemma);
+    auto silenceLemma = lexicon_.specialLemma("silence");
+    auto iters        = lexicon_.lemmaPronunciations();
+
+    // Iterate over the lemmata and add them to the tree
+    for (auto it = iters.first; it != iters.second; ++it) {
+        if ((*it)->lemma() == wordBoundaryLemma) {
+            // The wordBoundaryLemma should be a successor of the wordBoundaryRoot_
+            // This is handled separately in addWordBoundaryStates()
+            continue;
+        }
+
+        StateId lastState = extendPronunciation(network_.rootState, (*it)->pronunciation());
+
+        if (wordBoundaryLemma != nullptr && (*it)->lemma() != sentenceEndLemma && (*it)->lemma() != silenceLemma) {
+            // If existing, the wordBoundaryRoot_ should be the transit state for all word ends except sentence-end and silence
+            addExit(lastState, wordBoundaryRoot_, (*it)->id());
+        }
+        else {
+            addExit(lastState, network_.rootState, (*it)->id());
+        }
+    }
+}
+
+StateId AedTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronunciation const* pron) {
+    require(pron != nullptr);
+    StateId currentState = startState;
+
+    for (u32 i = 0u; i < pron->length(); i++) {
+        Bliss::Phoneme::Id phoneme = (*pron)[i];
+
+        u32 boundary = 0u;
+        if (i == 0) {
+            boundary |= Am::Allophone::isInitialPhone;
+        }
+        if ((i + 1) == pron->length()) {
+            boundary |= Am::Allophone::isFinalPhone;
+        }
+
+        Bliss::ContextPhonology::SemiContext history, future;
+        const Am::Allophone*                 allophone   = acousticModel_.allophoneAlphabet()->allophone(Am::Allophone(Bliss::ContextPhonology::PhonemeInContext(phoneme, history, future), boundary));
+        const Am::ClassicHmmTopology*        hmmTopology = acousticModel_.hmmTopology(phoneme);
+
+        for (u32 phoneState = 0; phoneState < hmmTopology->nPhoneStates(); ++phoneState) {
+            Am::AllophoneState   alloState = acousticModel_.allophoneStateAlphabet()->allophoneState(allophone, phoneState);
+            StateTree::StateDesc desc;
+            desc.acousticModel = acousticModel_.emissionIndex(alloState);  // state-tying look-up
+
+            for (u32 subState = 0; subState < hmmTopology->nSubStates(); ++subState) {
+                desc.transitionModelIndex = acousticModel_.stateTransitionIndex(alloState, subState);
+                verify(desc.transitionModelIndex < Core::Type::max);
+
+                // Add new state
+                currentState = extendState(currentState, desc);
+            }
+        }
+    }
+
+    return currentState;
+}
+
+void AedTreeBuilder::addWordBoundaryStates() {
+    Bliss::Lemma const* wordBoundaryLemma = lexicon_.specialLemma("word-boundary");
+    if (wordBoundaryLemma == nullptr) {
+        return;
+    }
+
+    // Add the word-boundary to the tree, starting from the wordBoundaryRoot_
+    // If the word-boundary has several pronunciations, only the first one is considered
+    auto prons = wordBoundaryLemma->pronunciations();
+
+    Bliss::LemmaPronunciation const* wordBoundaryPronLemma = prons.first;
+    require(wordBoundaryPronLemma != nullptr);  // checked before the pointer is dereferenced below
+
+    StateId wordBoundaryEnd = extendPronunciation(wordBoundaryRoot_, wordBoundaryPronLemma->pronunciation());
+    require(wordBoundaryEnd != 0);
+
+    // The "normal" root is the transition state from the word-boundary token, such that a new word can be started afterwards
+    addExit(wordBoundaryEnd, network_.rootState, wordBoundaryPronLemma->id());
+}
diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh
index 78b0c829d..23e97fed3 100644
--- a/src/Search/TreeBuilder.hh
+++ b/src/Search/TreeBuilder.hh
@@ -54,8 +54,10 @@ protected:
 
     ExitHash exitHash_;
 
+    // Allocate a new tree node with the StateDesc `desc`
     StateId createState(Search::StateTree::StateDesc desc);
-    u32     createExit(Search::PersistentStateTree::Exit exit);
+    // Create a new exit state if it does not exist yet
+    u32 createExit(Search::PersistentStateTree::Exit exit);
 };
 
 class MinimizedTreeBuilder : public AbstractTreeBuilder {
@@ -247,7 +249,26 @@ protected:
     void mapSuccessors(const std::set&, std::set&, const std::vector&, const std::vector&);
 };
 
-class CtcTreeBuilder : public AbstractTreeBuilder {
+class SharedBaseClassTreeBuilder : public AbstractTreeBuilder {
+public:
+    SharedBaseClassTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network);
+    virtual ~SharedBaseClassTreeBuilder() = default;
+
+protected:
+    // Create a node with invalid AM and TM indices which serves as a root
+    StateId createRoot();
+    // Check if a node with StateDesc `desc` is already a successor of the state with ID `predecessor` and add it if not.
+    // Returns the ID of the successor state.
+    StateId extendState(StateId predecessor, Search::StateTree::StateDesc desc);
+    // Add a transition between two already existing states `predecessor` and `successor`, used to insert loops and skip-transitions
+    void addTransition(StateId predecessor, StateId successor);
+    // Add an exit from the last state `state` of a word with pronunciation `pron` leading to root node `transitState`.
+    // The exit is appended to `state`'s successors.
+    // Returns the ID of the exit.
+    u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron);
+};
+
+class CtcTreeBuilder : public SharedBaseClassTreeBuilder {
 public:
     static const Core::ParameterBool paramLabelLoop;
     static const Core::ParameterBool paramBlankLoop;
@@ -270,25 +291,10 @@ protected:
     Search::StateTree::StateDesc blankDesc_;
     Am::AllophoneStateIndex      blankAllophoneStateIndex_;
 
-    // Create a node with invalid AM and TM indices which serves as a root
-    StateId createRoot();
-
-    // Add an exit from the last state `state` of a word with pronunciation `pron` leading to root node `transitState`.
-    // The exit is appended to `state`'s successors.
-    // Returns the ID of the exit.
-    u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron);
-
-    // Check if a node with StateDesc `desc` is already a successor of the state with ID `predecessor` and add it if not.
-    // Returns the ID of the successor state.
-    StateId extendState(StateId predecessor, Search::StateTree::StateDesc desc);
-
     // Starting in `startState` (usually a root), include the lemma with pronunciation `pron` in the tree
     // Returns the last state corresponding to `pron`.
     StateId extendPronunciation(StateId startState, Bliss::Pronunciation const* pron);
 
-    // Add a transition between two already existing states `predecessor` and `successor`, used to insert loops and skip-transitions
-    void addTransition(StateId predecessor, StateId successor);
-
     // Build the sub-tree with the word-boundary lemma plus optional blank starting from `wordBoundaryRoot_`.
     void addWordBoundaryStates();
 };
@@ -302,4 +308,25 @@ public:
     virtual ~RnaTreeBuilder() = default;
 };
 
+class AedTreeBuilder : public SharedBaseClassTreeBuilder {
+public:
+    AedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);
+    virtual ~AedTreeBuilder() = default;
+
+    virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);
+
+    // Build a new persistent state network.
+    virtual void build();
+
+protected:
+    StateId wordBoundaryRoot_ = 0;  // only assigned in the constructor if `initialize` is set and a "word-boundary" lemma exists; 0 is used as "no state" (cf. require(wordBoundaryEnd != 0))
+
+    // Starting in `startState` (usually a root), include the lemma with pronunciation `pron` in the tree
+    // Returns the last state corresponding to `pron`.
+    StateId extendPronunciation(StateId startState, Bliss::Pronunciation const* pron);
+
+    // Build the sub-tree with the word-boundary lemma starting from `wordBoundaryRoot_`.
+    void addWordBoundaryStates();
+};
+
 #endif