Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Search/Module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const Core::Choice choiceTreeBuilderType(
"minimized-hmm", static_cast<int>(TreeBuilderType::minimizedHmm),
"ctc", static_cast<int>(TreeBuilderType::ctc),
"rna", static_cast<int>(TreeBuilderType::rna),
"aed", static_cast<int>(TreeBuilderType::aed),
Core::Choice::endMark());

const Core::ParameterChoice paramTreeBuilderType(
Expand All @@ -70,6 +71,9 @@ std::unique_ptr<AbstractTreeBuilder> Module_::createTreeBuilder(Core::Configurat
case Search::TreeBuilderType::rna: {
return std::unique_ptr<AbstractTreeBuilder>(new RnaTreeBuilder(config, lexicon, acousticModel, network, initialize));
} break;
case Search::TreeBuilderType::aed: {
return std::unique_ptr<AbstractTreeBuilder>(new AedTreeBuilder(config, lexicon, acousticModel, network, initialize));
} break;
default: defect();
}
}
Expand Down
1 change: 1 addition & 0 deletions src/Search/Module.hh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum TreeBuilderType {
minimizedHmm,
ctc,
rna,
aed,
};

enum SearchType {
Expand Down
232 changes: 178 additions & 54 deletions src/Search/TreeBuilder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,67 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set<StateId>& success
}
}

// -------------------- SharedBaseClassTreeBuilder --------------------

SharedBaseClassTreeBuilder::SharedBaseClassTreeBuilder(Core::Configuration config,
const Bliss::Lexicon& lexicon,
const Am::AcousticModel& acousticModel,
Search::PersistentStateTree& network)
: AbstractTreeBuilder(config, lexicon, acousticModel, network) {}

StateId SharedBaseClassTreeBuilder::createRoot() {
return createState(StateTree::StateDesc(Search::StateTree::invalidAcousticModel, Am::TransitionModel::entryM1));
}

StateId SharedBaseClassTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc) {
// Check if the successor already exists
for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) {
return *target;
}
}

// No matching successor found, extend
StateId ret = createState(desc);
network_.structure.addTargetToNode(predecessor, ret);
return ret;
}

void SharedBaseClassTreeBuilder::addTransition(StateId predecessor, StateId successor) {
auto const& predecessorStateDesc = network_.structure.state(predecessor).stateDesc;
auto const& successorStateDesc = network_.structure.state(successor).stateDesc;

for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) {
// The node is already a successor of the predecessor, so the transition already exists
return;
}
}

// The transition does not exists yet, add it
network_.structure.addTargetToNode(predecessor, successor);
}

u32 SharedBaseClassTreeBuilder::addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron) {
PersistentStateTree::Exit exit;
exit.transitState = transitState;
exit.pronunciation = pron;

u32 exitIndex = createExit(exit);

// Check if the exit is already a successor
// This should only happen if the same lemma is contained multiple times in the lexicon
for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(state); target; ++target) {
if (target.isLabel() && target.label() == exitIndex) {
return exitIndex;
}
}

// The exit is not part of the successors yet, add it
network_.structure.addOutputToNode(state, ID_FROM_LABEL(exitIndex));
return exitIndex;
}

// -------------------- CtcTreeBuilder --------------------

const Core::ParameterBool CtcTreeBuilder::paramLabelLoop(
Expand All @@ -1221,7 +1282,7 @@ const Core::ParameterBool CtcTreeBuilder::paramForceBlank(
true);

CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
: AbstractTreeBuilder(config, lexicon, acousticModel, network),
: SharedBaseClassTreeBuilder(config, lexicon, acousticModel, network),
labelLoop_(paramLabelLoop(config)),
blankLoop_(paramBlankLoop(config)),
forceBlank_(paramForceBlank(config)) {
Expand Down Expand Up @@ -1282,59 +1343,6 @@ void CtcTreeBuilder::build() {
}
}

StateId CtcTreeBuilder::createRoot() {
return createState(StateTree::StateDesc(Search::StateTree::invalidAcousticModel, Am::TransitionModel::entryM1));
}

u32 CtcTreeBuilder::addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron) {
PersistentStateTree::Exit exit;
exit.transitState = transitState;
exit.pronunciation = pron;

u32 exitIndex = createExit(exit);

// Check if the exit is already a successor
// This should only happen if the same lemma is contained multiple times in the lexicon
for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(state); target; ++target) {
if (target.isLabel() && target.label() == exitIndex) {
return exitIndex;
}
}

// The exit is not part of the successors yet, add it
network_.structure.addOutputToNode(state, ID_FROM_LABEL(exitIndex));
return exitIndex;
}

StateId CtcTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc) {
// Check if the successor already exists
for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) {
return *target;
}
}

// No matching successor found, extend
StateId ret = createState(desc);
network_.structure.addTargetToNode(predecessor, ret);
return ret;
}

void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) {
auto const& predecessorStateDesc = network_.structure.state(predecessor).stateDesc;
auto const& successorStateDesc = network_.structure.state(successor).stateDesc;

for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) {
// The node is already a successor of the predecessor, so the transition already exists
return;
}
}

// The transition does not exists yet, add it
network_.structure.addTargetToNode(predecessor, successor);
}

StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronunciation const* pron) {
require(pron != nullptr);

Expand Down Expand Up @@ -1453,3 +1461,119 @@ RnaTreeBuilder::RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon&
this->labelLoop_ = paramLabelLoop(config);
this->forceBlank_ = paramForceBlank(config);
}

// -------------------- AedTreeBuilder --------------------

AedTreeBuilder::AedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
: SharedBaseClassTreeBuilder(config, lexicon, acousticModel, network) {
auto iters = lexicon.phonemeInventory()->phonemes();
for (auto it = iters.first; it != iters.second; ++it) {
require(not(*it)->isContextDependent()); // Context dependent labels are not supported
}

if (initialize) {
verify(!network_.rootState);
network_.ciRootState = network_.rootState = createRoot();

// Create a special root for the word-boundary token if it exists in the lexicon
if (lexicon.specialLemma("word-boundary") != nullptr) {
wordBoundaryRoot_ = createRoot();
network_.otherRootStates.insert(wordBoundaryRoot_);
}
}
}

std::unique_ptr<AbstractTreeBuilder> AedTreeBuilder::newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) {
return std::unique_ptr<AbstractTreeBuilder>(new AedTreeBuilder(config, lexicon, acousticModel, network));
}

void AedTreeBuilder::build() {
auto wordBoundaryLemma = lexicon_.specialLemma("word-boundary");
if (wordBoundaryLemma != nullptr) {
addWordBoundaryStates();
}

auto sentenceEndLemma = lexicon_.specialLemma("sentence-end");
if (!sentenceEndLemma) {
sentenceEndLemma = lexicon_.specialLemma("sentence-boundary");
}
require(sentenceEndLemma);
auto silenceLemma = lexicon_.specialLemma("silence");
auto iters = lexicon_.lemmaPronunciations();

// Iterate over the lemmata and add them to the tree
for (auto it = iters.first; it != iters.second; ++it) {
if ((*it)->lemma() == wordBoundaryLemma) {
// The wordBoundaryLemma should be a successor of the wordBoundaryRoot_
// This is handled separately in addWordBoundaryStates()
continue;
}

StateId lastState = extendPronunciation(network_.rootState, (*it)->pronunciation());

if (wordBoundaryLemma != nullptr && (*it)->lemma() != sentenceEndLemma && (*it)->lemma() != silenceLemma) {
// If existing, the wordBoundaryRoot_ should be the transit state for all word ends except sentence-end and silence
addExit(lastState, wordBoundaryRoot_, (*it)->id());
}
else {
addExit(lastState, network_.rootState, (*it)->id());
}
}
}

StateId AedTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronunciation const* pron) {
require(pron != nullptr);
StateId currentState = startState;

for (u32 i = 0u; i < pron->length(); i++) {
Bliss::Phoneme::Id phoneme = (*pron)[i];

u32 boundary = 0u;
if (i == 0) {
boundary |= Am::Allophone::isInitialPhone;
}
if ((i + 1) == pron->length()) {
boundary |= Am::Allophone::isFinalPhone;
}

Bliss::ContextPhonology::SemiContext history, future;
const Am::Allophone* allophone = acousticModel_.allophoneAlphabet()->allophone(Am::Allophone(Bliss::ContextPhonology::PhonemeInContext(phoneme, history, future), boundary));
const Am::ClassicHmmTopology* hmmTopology = acousticModel_.hmmTopology(phoneme);

for (u32 phoneState = 0; phoneState < hmmTopology->nPhoneStates(); ++phoneState) {
Am::AllophoneState alloState = acousticModel_.allophoneStateAlphabet()->allophoneState(allophone, phoneState);
StateTree::StateDesc desc;
desc.acousticModel = acousticModel_.emissionIndex(alloState); // state-tying look-up

for (u32 subState = 0; subState < hmmTopology->nSubStates(); ++subState) {
desc.transitionModelIndex = acousticModel_.stateTransitionIndex(alloState, subState);
verify(desc.transitionModelIndex < Core::Type<StateTree::StateDesc::TransitionModelIndex>::max);

// Add new state
currentState = extendState(currentState, desc);
}
}
}

return currentState;
}

void AedTreeBuilder::addWordBoundaryStates() {
Bliss::Lemma const* wordBoundaryLemma = lexicon_.specialLemma("word-boundary");
if (wordBoundaryLemma == nullptr) {
return;
}

// Add the word-boundary to the tree, starting from the wordBoundaryRoot_
// If the word-boundary has several pronunciation, only the first one is considered
auto prons = wordBoundaryLemma->pronunciations();

StateId wordBoundaryEnd = extendPronunciation(wordBoundaryRoot_, (prons.first)->pronunciation());
require(wordBoundaryEnd != 0);

Bliss::LemmaPronunciation const* wordBoundaryPronLemma = prons.first;
require(wordBoundaryPronLemma != nullptr);

// The "normal" root is the transition state from the word-boundary token, such that a new word can be started afterwards
addExit(wordBoundaryEnd, network_.rootState, wordBoundaryPronLemma->id());
}
61 changes: 44 additions & 17 deletions src/Search/TreeBuilder.hh
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ protected:

ExitHash exitHash_;

// Allocate a new tree node with the StateDesc `desc`
StateId createState(Search::StateTree::StateDesc desc);
u32 createExit(Search::PersistentStateTree::Exit exit);
// Create a new exit state if it does not exist yet
u32 createExit(Search::PersistentStateTree::Exit exit);
};

class MinimizedTreeBuilder : public AbstractTreeBuilder {
Expand Down Expand Up @@ -247,7 +249,26 @@ protected:
void mapSuccessors(const std::set<StateId>&, std::set<StateId>&, const std::vector<StateId>&, const std::vector<u32>&);
};

class CtcTreeBuilder : public AbstractTreeBuilder {
class SharedBaseClassTreeBuilder : public AbstractTreeBuilder {
public:
SharedBaseClassTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network);
virtual ~SharedBaseClassTreeBuilder() = default;

protected:
// Create a node with invalid AM and TM indices which serves as a root
StateId createRoot();
// Check if a node with StateDesc `desc` is already a successor of the state with ID `predecessor` and add it if not.
// Returns the ID of the successor state.
StateId extendState(StateId predecessor, Search::StateTree::StateDesc desc);
// Add a transition between two already existing states `predecessor` and `successor`, used to insert loops and skip-transitions
void addTransition(StateId predecessor, StateId successor);
// Add an exit from the last state `state` of a word with pronunciation `pron` leading to root node `transitState`.
// The exit is appended to `state`'s successors.
// Returns the ID of the exit.
u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron);
};

class CtcTreeBuilder : public SharedBaseClassTreeBuilder {
public:
static const Core::ParameterBool paramLabelLoop;
static const Core::ParameterBool paramBlankLoop;
Expand All @@ -270,25 +291,10 @@ protected:
Search::StateTree::StateDesc blankDesc_;
Am::AllophoneStateIndex blankAllophoneStateIndex_;

// Create a node with invalid AM and TM indices which serves as a root
StateId createRoot();

// Add an exit from the last state `state` of a word with pronunciation `pron` leading to root node `transitState`.
// The exit is appended to `state`'s successors.
// Returns the ID of the exit.
u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron);

// Check if a node with StateDesc `desc` is already a successor of the state with ID `predecessor` and add it if not.
// Returns the ID of the successor state.
StateId extendState(StateId predecessor, Search::StateTree::StateDesc desc);

// Starting in `startState` (usually a root), include the lemma with pronunciation `pron` in the tree
// Returns the last state corresponding to `pron`.
StateId extendPronunciation(StateId startState, Bliss::Pronunciation const* pron);

// Add a transition between two already existing states `predecessor` and `successor`, used to insert loops and skip-transitions
void addTransition(StateId predecessor, StateId successor);

// Build the sub-tree with the word-boundary lemma plus optional blank starting from `wordBoundaryRoot_`.
void addWordBoundaryStates();
};
Expand All @@ -302,4 +308,25 @@ public:
virtual ~RnaTreeBuilder() = default;
};

class AedTreeBuilder : public SharedBaseClassTreeBuilder {
public:
AedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);
virtual ~AedTreeBuilder() = default;

virtual std::unique_ptr<AbstractTreeBuilder> newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);

// Build a new persistent state network.
virtual void build();

protected:
StateId wordBoundaryRoot_;

// Starting in `startState` (usually a root), include the lemma with pronunciation `pron` in the tree
// Returns the last state corresponding to `pron`.
StateId extendPronunciation(StateId startState, Bliss::Pronunciation const* pron);

// Build the sub-tree with the word-boundary lemma starting from `wordBoundaryRoot_`.
void addWordBoundaryStates();
};

#endif