Skip to content

Commit 6a9ac9c

Browse files
larissaklcurufinweSimon Berger
authored
Refactor tree construction in AdvancedTreeSearch to allow multiple TreeBuilder types (#85)
Co-authored-by: Eugen Beck <[email protected]> Co-authored-by: Simon Berger <[email protected]>
1 parent 8dfd39b commit 6a9ac9c

File tree

6 files changed

+808
-700
lines changed

6 files changed

+808
-700
lines changed

src/Search/AdvancedTreeSearch/PersistentStateTree.cc

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct ConvertTree {
4141
TreeIndex masterTreeIndex;
4242
StateId rootSubTree;
4343
StateId ciRootNode;
44-
std::map<StateTree::Exit, u32> exits; //Maps exits to label-indices @todo Make this a hash_map
44+
std::map<StateTree::Exit, u32> exits; // Maps exits to label-indices @todo Make this a hash_map
4545
std::vector<PersistentStateTree::Exit> exitVector;
4646
Core::HashMap<StateId, StateTree::StateId> statesForNodes;
4747
Core::HashMap<StateTree::StateId, StateId> nodesForStates;
@@ -73,7 +73,7 @@ struct ConvertTree {
7373
}
7474
}
7575

76-
///Make sure a node is created for every single state, so that also the coarticulated roots are respected
76+
/// Make sure a node is created for every single state, so that also the coarticulated roots are respected
7777

7878
for (std::set<StateTree::StateId>::iterator stateIt = coarticulatedRootStates.begin(); stateIt != coarticulatedRootStates.end(); ++stateIt) {
7979
StateTree::StateId state = *stateIt;
@@ -121,7 +121,7 @@ struct ConvertTree {
121121
exitIndices.insert(exitEntry->second);
122122
}
123123

124-
//Add connections to the attached outputs/exits
124+
// Add connections to the attached outputs/exits
125125
for (std::set<u32>::iterator it = exitIndices.begin(); it != exitIndices.end(); ++it)
126126
subtrees.addOutputToEdge(subtrees.state(node).successors, *it);
127127
}
@@ -150,10 +150,10 @@ struct ConvertTree {
150150

151151
subtrees.state(node).stateDesc = state;
152152

153-
//Build successor structure
153+
// Build successor structure
154154
std::pair<StateTree::SuccessorIterator, StateTree::SuccessorIterator> successors = tree->successors(stateId);
155155

156-
StateId current = node; //Just to verify the order
156+
StateId current = node; // Just to verify the order
157157

158158
for (; successors.first != successors.second; ++successors.first) {
159159
std::unordered_map<StateTree::StateId, StateId>::iterator nodeIt = nodesForStates.find(*successors.first);
@@ -166,14 +166,15 @@ struct ConvertTree {
166166
}
167167
};
168168

169-
PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon)
169+
PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory)
170170
: masterTree(0),
171171
rootState(0),
172172
ciRootState(0),
173173
archive_(paramCacheArchive(Core::Configuration(config, "search-network"))),
174174
acousticModel_(acousticModel),
175175
lexicon_(lexicon),
176-
config_(config) {
176+
config_(config),
177+
treeBuilderFactory_(treeBuilderFactory) {
177178
if (acousticModel_.get() && lexicon_.get()) {
178179
const Am::ClassicAcousticModel* am = required_cast(const Am::ClassicAcousticModel*, acousticModel.get());
179180
Core::DependencySet d;
@@ -320,7 +321,7 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) {
320321
in >> masterTree >> dependenciesChecksum;
321322

322323
if (dependenciesChecksum != dependencies_.getChecksum()) {
323-
Core::Application::us()->log() << "dependencies of the network image don't equal the requiered dependencies with checksum " << dependenciesChecksum;
324+
Core::Application::us()->log() << "dependencies of the network image don't equal the required dependencies with checksum " << dependenciesChecksum;
324325
return false;
325326
}
326327

@@ -436,7 +437,7 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) {
436437

437438
Core::HashMap<StateId, StateId>::const_iterator targetNodeIt;
438439
if (rootState) {
439-
verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); //Root-node must stay unchanged
440+
verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); // Root-node must stay unchanged
440441
verify(cleanupResult.nodeMap.find(rootState)->second == rootState);
441442
targetNodeIt = cleanupResult.nodeMap.find(rootState);
442443
verify(targetNodeIt != cleanupResult.nodeMap.end());
@@ -512,7 +513,7 @@ void PersistentStateTree::dumpDotGraph(std::string file, const std::vector<int>&
512513
int depth = 0;
513514
if (!nodeDepths.empty())
514515
depth = nodeDepths[node];
515-
os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d", node, node, depth, structure.state(node).stateDesc.acousticModel);
516+
os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d\\nt=%d", node, node, depth, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex);
516517

517518
for (HMMStateNetwork::SuccessorIterator target = structure.successors(node); target; ++target)
518519
if (target.isLabel() && exits[target.label()].pronunciation != Bliss::LemmaPronunciation::invalidId)

src/Search/AdvancedTreeSearch/PersistentStateTree.hh

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,34 +31,38 @@ struct MyStandardValueHash {
3131
}
3232
};
3333

34+
class AbstractTreeBuilder;
35+
3436
namespace Search {
3537
class HMMStateNetwork;
3638
class StateTree;
3739

3840
class PersistentStateTree {
3941
public:
42+
using TreeBuilderFactory = std::function<std::unique_ptr<AbstractTreeBuilder>(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>;
43+
4044
///@param lexicon This must be given if the resulting exits are supposed to be functional
41-
PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon);
45+
PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory);
4246

43-
///Builds this state tree.
47+
/// Builds this state tree.
4448
void build();
4549

46-
///Writes the current state of the state tree into the file,
47-
///Returns whether writing was successful
50+
/// Writes the current state of the state tree into the file,
51+
/// Returns whether writing was successful
4852
bool write(int transformation = 0);
4953

50-
///Reads the state tree from the file.
54+
/// Reads the state tree from the file.
5155
///@return Whether the reading was successful.
5256
bool read(int transformation = 0);
5357

54-
///Cleans up the structure, saving memory and allowing a more efficient iteration.
55-
///Node and tree IDs may be changed.
58+
/// Cleans up the structure, saving memory and allowing a more efficient iteration.
59+
/// Node and tree IDs may be changed.
5660
///@return An object that contains a mapping representing the index changes.
5761
HMMStateNetwork::CleanupResult cleanup(bool cleanupExits = true);
5862

59-
///Removes all outputs from the network
60-
///Also performs a cleanup, so the search network must already be clean
61-
///for indices to stay equal
63+
/// Removes all outputs from the network
64+
/// Also performs a cleanup, so the search network must already be clean
65+
/// for indices to stay equal
6266
void removeOutputs();
6367

6468
u32 getChecksum() const;
@@ -128,11 +132,12 @@ private:
128132
Core::Ref<const Am::AcousticModel> acousticModel_;
129133
Bliss::LexiconRef lexicon_;
130134
Core::Configuration config_;
135+
TreeBuilderFactory treeBuilderFactory_;
131136

132-
//Writes the whole state network into the given stream
137+
// Writes the whole state network into the given stream
133138
void write(Core::MappedArchiveWriter writer);
134139

135-
//Reads the state network from the given stream.
140+
// Reads the state network from the given stream.
136141
//@return Whether the reading was successful.
137142
bool read(Core::MappedArchiveReader reader);
138143
};

0 commit comments

Comments
 (0)