Skip to content

Commit 776a088

Browse files
committed
RequirementMachine: Parametrize Trie by value type and add support for longest-suffix match
The PropertyMap wants to use a try to map to PropertyBags, and it needs longest-suffix rather than shortest-prefix matching. Implement both by making Trie into a template class with two parameters; the ValueType, and the MatchKind. Note that while the MatchKind encodes the longest vs shortest match part, matching on the prefix vs suffix of a term is up to the caller, since the find() and insert() methods of Trie take a pair of iterators, so simply passing in begin()/end() vs rbegin()/rend() selects the direction.
1 parent 9f71ee1 commit 776a088

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ RewriteSystem::RewriteSystem(RewriteContext &ctx)
3232
}
3333

3434
RewriteSystem::~RewriteSystem() {
35-
RuleTrie.updateHistograms(Context.RuleTrieHistogram,
36-
Context.RuleTrieRootHistogram);
35+
Trie.updateHistograms(Context.RuleTrieHistogram,
36+
Context.RuleTrieRootHistogram);
3737
}
3838

3939
void Rule::dump(llvm::raw_ostream &out) const {
@@ -99,7 +99,7 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
9999

100100
unsigned i = Rules.size();
101101
Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));
102-
auto oldRuleID = RuleTrie.insert(lhs.begin(), lhs.end(), i);
102+
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), i);
103103
if (oldRuleID) {
104104
llvm::errs() << "Duplicate rewrite rule!\n";
105105
const auto &oldRule = Rules[*oldRuleID];
@@ -167,7 +167,7 @@ bool RewriteSystem::simplify(MutableTerm &term) const {
167167
auto from = term.begin();
168168
auto end = term.end();
169169
while (from < end) {
170-
auto ruleID = RuleTrie.find(from, end);
170+
auto ruleID = Trie.find(from, end);
171171
if (ruleID) {
172172
const auto &rule = Rules[*ruleID];
173173
if (!rule.isDeleted()) {

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ class RewriteSystem final {
9898
/// as rules introduced by the completion procedure.
9999
std::vector<Rule> Rules;
100100

101-
/// A prefix trie of rule left hand sides to optimize lookup.
102-
Trie RuleTrie;
101+
/// A prefix trie of rule left hand sides to optimize lookup. The value
102+
/// type is an index into the Rules array defined above.
103+
Trie<unsigned, MatchKind::Shortest> Trie;
103104

104105
/// The graph of all protocols transitively referenced via our set of
105106
/// rewrite rules, used for the linear order on symbols.

lib/AST/RequirementMachine/Trie.h

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@ namespace swift {
2020

2121
namespace rewriting {
2222

23+
enum class MatchKind {
24+
Shortest,
25+
Longest
26+
};
27+
28+
template<typename ValueType, MatchKind Kind>
2329
class Trie {
2430
public:
2531
struct Node;
2632

2733
struct Entry {
28-
Optional<unsigned> RuleID;
34+
Optional<ValueType> Value;
2935
Node *Children = nullptr;
3036
};
3137

@@ -47,8 +53,10 @@ class Trie {
4753
rootStats.add(Root.Entries.size());
4854
}
4955

50-
/// The destructor deletes all nodes.
51-
~Trie() {
56+
/// Delete all entries from the trie.
57+
void clear() {
58+
Root.Entries.clear();
59+
5260
for (auto iter = Nodes.rbegin(); iter != Nodes.rend(); ++iter) {
5361
auto *node = *iter;
5462
delete node;
@@ -57,12 +65,16 @@ class Trie {
5765
Nodes.clear();
5866
}
5967

68+
~Trie() {
69+
clear();
70+
}
71+
6072
/// Insert an entry with the key given by the range [begin, end).
6173
/// Returns the old value if the trie already had an entry for this key;
6274
/// this is actually an invariant violation, but we can produce a better
6375
/// assertion further up the stack.
6476
template<typename Iter>
65-
Optional<unsigned> insert(Iter begin, Iter end, unsigned ruleID) {
77+
Optional<ValueType> insert(Iter begin, Iter end, ValueType value) {
6678
assert(begin != end);
6779
auto *node = &Root;
6880

@@ -71,10 +83,10 @@ class Trie {
7183
++begin;
7284

7385
if (begin == end) {
74-
if (entry.RuleID)
75-
return entry.RuleID;
86+
if (entry.Value)
87+
return entry.Value;
7688

77-
entry.RuleID = ruleID;
89+
entry.Value = value;
7890
return None;
7991
}
8092

@@ -87,26 +99,38 @@ class Trie {
8799
}
88100
}
89101

90-
/// Find the shortest prefix of the range given by [begin,end).
102+
/// Find the shortest or longest prefix of the range given by [begin,end),
103+
/// depending on whether the Kind template parameter was bound to
104+
/// MatchKind::Shortest or MatchKind::Longest.
91105
template<typename Iter>
92-
Optional<unsigned>
106+
Optional<ValueType>
93107
find(Iter begin, Iter end) const {
94108
assert(begin != end);
95109
auto *node = &Root;
96110

111+
Optional<ValueType> bestMatch = None;
112+
97113
while (true) {
98114
auto found = node->Entries.find(*begin);
99115
++begin;
100116

101117
if (found == node->Entries.end())
102-
return None;
118+
return bestMatch;
103119

104120
const auto &entry = found->second;
105-
if (begin == end || entry.RuleID)
106-
return entry.RuleID;
121+
122+
if (entry.Value) {
123+
if (Kind == MatchKind::Shortest)
124+
return entry.Value;
125+
126+
bestMatch = entry.Value;
127+
}
128+
129+
if (begin == end)
130+
return bestMatch;
107131

108132
if (entry.Children == nullptr)
109-
return None;
133+
return bestMatch;
110134

111135
node = entry.Children;
112136
}

0 commit comments

Comments
 (0)