Skip to content

Commit 0935952

Browse files
committed
RequirementMachine: Cache result of mergeAssociatedTypes()
1 parent 05645a1 commit 0935952

File tree

4 files changed

+31
-7
lines changed

4 files changed

+31
-7
lines changed

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class RewriteContext final {
4848
/// Cache for associated type declarations.
4949
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;
5050

51+
/// Cache for merged associated type symbols.
52+
llvm::DenseMap<std::pair<Symbol, Symbol>, Symbol> MergedAssocTypes;
53+
5154
RewriteContext(const RewriteContext &) = delete;
5255
RewriteContext(RewriteContext &&) = delete;
5356
RewriteContext &operator=(const RewriteContext &) = delete;
@@ -91,6 +94,9 @@ class RewriteContext final {
9194
AssociatedTypeDecl *getAssociatedTypeForSymbol(Symbol symbol,
9295
const ProtocolGraph &protos);
9396

97+
Symbol mergeAssociatedTypes(Symbol lhs, Symbol rhs,
98+
const ProtocolGraph &protos);
99+
94100
~RewriteContext();
95101
};
96102

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ class RewriteSystem final {
173173
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
174174
const Rule &lhs, const Rule &rhs) const;
175175

176-
Symbol mergeAssociatedTypes(Symbol lhs, Symbol rhs) const;
177176
void processMergedAssociatedTypes();
178177
};
179178

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <deque>
3535
#include <vector>
3636

37+
#include "RewriteContext.h"
3738
#include "RewriteSystem.h"
3839

3940
using namespace swift;
@@ -72,12 +73,23 @@ Symbol Symbol::prependPrefixToConcreteSubstitutions(
7273
/// - If P inherits from Q, this is just [P:T].
7374
/// - If Q inherits from P, this is just [Q:T].
7475
/// - If P and Q are unrelated, this is [P&Q:T].
75-
Symbol RewriteSystem::mergeAssociatedTypes(Symbol lhs, Symbol rhs) const {
76+
///
77+
/// Note that the protocol graph is not part of the caching key; each
78+
/// protocol graph is a subgraph of the global inheritance graph, so
79+
/// the specific choice of subgraph does not change the result.
80+
Symbol RewriteContext::mergeAssociatedTypes(Symbol lhs, Symbol rhs,
81+
const ProtocolGraph &graph) {
82+
auto key = std::make_pair(lhs, rhs);
83+
84+
auto found = MergedAssocTypes.find(key);
85+
if (found != MergedAssocTypes.end())
86+
return found->second;
87+
7688
// Check preconditions that were established by RewriteSystem::addRule().
7789
assert(lhs.getKind() == Symbol::Kind::AssociatedType);
7890
assert(rhs.getKind() == Symbol::Kind::AssociatedType);
7991
assert(lhs.getName() == rhs.getName());
80-
assert(lhs.compare(rhs, Protos) > 0);
92+
assert(lhs.compare(rhs, graph) > 0);
8193

8294
auto protos = lhs.getProtocols();
8395
auto otherProtos = rhs.getProtocols();
@@ -92,7 +104,7 @@ Symbol RewriteSystem::mergeAssociatedTypes(Symbol lhs, Symbol rhs) const {
92104
std::back_inserter(newProtos),
93105
[&](const ProtocolDecl *lhs,
94106
const ProtocolDecl *rhs) -> int {
95-
return Protos.compareProtocols(lhs, rhs) < 0;
107+
return graph.compareProtocols(lhs, rhs) < 0;
96108
});
97109

98110
// Prune duplicates and protocols that are inherited by other
@@ -101,7 +113,7 @@ Symbol RewriteSystem::mergeAssociatedTypes(Symbol lhs, Symbol rhs) const {
101113
for (const auto *newProto : newProtos) {
102114
auto inheritsFrom = [&](const ProtocolDecl *thisProto) {
103115
return (thisProto == newProto ||
104-
Protos.inheritsFrom(thisProto, newProto));
116+
graph.inheritsFrom(thisProto, newProto));
105117
};
106118

107119
if (std::find_if(minimalProtos.begin(), minimalProtos.end(),
@@ -120,7 +132,12 @@ Symbol RewriteSystem::mergeAssociatedTypes(Symbol lhs, Symbol rhs) const {
120132
// of the two sets.
121133
assert(minimalProtos.size() <= protos.size() + otherProtos.size());
122134

123-
return Symbol::forAssociatedType(minimalProtos, lhs.getName(), Context);
135+
auto result = Symbol::forAssociatedType(minimalProtos, lhs.getName(), *this);
136+
auto inserted = MergedAssocTypes.insert(std::make_pair(key, result));
137+
assert(inserted.second);
138+
(void) inserted;
139+
140+
return result;
124141
}
125142

126143
/// Consider the following example:
@@ -207,7 +224,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
207224
llvm::dbgs() << lhs << " => " << rhs << "\n";
208225
}
209226

210-
auto mergedSymbol = mergeAssociatedTypes(lhs.back(), rhs.back());
227+
auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
228+
Protos);
211229
if (DebugMerge) {
212230
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
213231
}

lib/AST/RequirementMachine/Symbol.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class LayoutConstraint;
3030
namespace rewriting {
3131

3232
class MutableTerm;
33+
class ProtocolGraph;
3334
class RewriteContext;
3435
class Term;
3536

0 commit comments

Comments
 (0)