Skip to content

Commit 1d8dd94

Browse files
committed
RequirementMachine: Refactor RuleBuilder in preparation for rule sharing
1 parent 1f125c3 commit 1d8dd94

File tree

5 files changed

+112
-94
lines changed

5 files changed

+112
-94
lines changed

lib/AST/RequirementMachine/ConcreteTypeWitness.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,8 @@ void PropertyMap::inferConditionalRequirements(
562562
llvm::dbgs() << "@@@ Unknown protocol: "<< proto->getName() << "\n";
563563
}
564564

565-
RuleBuilder builder(Context, System.getProtocolMap());
566-
builder.addProtocol(proto, /*initialComponent=*/false);
565+
RuleBuilder builder(Context, System.getReferencedProtocols());
566+
builder.addReferencedProtocol(proto);
567567
builder.collectRulesFromReferencedProtocols();
568568

569569
for (const auto &rule : builder.PermanentRules)

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 77 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {
10511051
// Collect all protocols transitively referenced from these requirements.
10521052
for (auto req : requirements) {
10531053
if (req.getKind() == RequirementKind::Conformance) {
1054-
addProtocol(req.getProtocolDecl(), /*initialComponent=*/false);
1054+
addReferencedProtocol(req.getProtocolDecl());
10551055
}
10561056
}
10571057

@@ -1066,7 +1066,7 @@ void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements)
10661066
// Collect all protocols transitively referenced from these requirements.
10671067
for (auto req : requirements) {
10681068
if (req.req.getKind() == RequirementKind::Conformance) {
1069-
addProtocol(req.req.getProtocolDecl(), /*initialComponent=*/false);
1069+
addReferencedProtocol(req.req.getProtocolDecl());
10701070
}
10711071
}
10721072

@@ -1077,16 +1077,65 @@ void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements)
10771077
addRequirement(req, /*proto=*/nullptr);
10781078
}
10791079

1080+
/// For building a rewrite system for a protocol connected component from
1081+
/// user-written requirements. Used when actually building requirement
1082+
/// signatures.
10801083
void RuleBuilder::addProtocols(ArrayRef<const ProtocolDecl *> protos) {
1081-
// Collect all protocols transitively referenced from this connected component
1082-
// of the protocol dependency graph.
1083-
for (auto proto : protos) {
1084-
addProtocol(proto, /*initialComponent=*/true);
1084+
for (auto *proto : protos) {
1085+
ReferencedProtocols.insert(proto);
1086+
}
1087+
1088+
for (auto *proto : protos) {
1089+
if (Dump) {
1090+
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
1091+
}
1092+
1093+
addPermanentProtocolRules(proto);
1094+
1095+
for (auto req : proto->getStructuralRequirements())
1096+
addRequirement(req, proto);
1097+
1098+
for (auto req : proto->getTypeAliasRequirements())
1099+
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1100+
1101+
for (auto *otherProto : proto->getProtocolDependencies())
1102+
addReferencedProtocol(otherProto);
1103+
1104+
if (Dump) {
1105+
llvm::dbgs() << "}\n";
1106+
}
10851107
}
10861108

1109+
// Collect all protocols transitively referenced from this connected component
1110+
// of the protocol dependency graph.
10871111
collectRulesFromReferencedProtocols();
10881112
}
10891113

1114+
/// Add permanent rules for a protocol, consisting of:
1115+
///
1116+
/// - The identity conformance rule [P].[P] => [P].
1117+
/// - An associated type introduction rule for each associated type.
1118+
/// - An inherited associated type introduction rule for each associated
1119+
/// type of each inherited protocol.
1120+
void RuleBuilder::addPermanentProtocolRules(const ProtocolDecl *proto) {
1121+
MutableTerm lhs;
1122+
lhs.add(Symbol::forProtocol(proto, Context));
1123+
lhs.add(Symbol::forProtocol(proto, Context));
1124+
1125+
MutableTerm rhs;
1126+
rhs.add(Symbol::forProtocol(proto, Context));
1127+
1128+
PermanentRules.emplace_back(lhs, rhs);
1129+
1130+
for (auto *assocType : proto->getAssociatedTypeMembers())
1131+
addAssociatedType(assocType, proto);
1132+
1133+
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
1134+
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
1135+
addAssociatedType(assocType, proto);
1136+
}
1137+
}
1138+
10901139
/// For an associated type T in a protocol P, we add a rewrite rule:
10911140
///
10921141
/// [P].T => [P:T]
@@ -1264,72 +1313,44 @@ void RuleBuilder::addTypeAlias(const ProtocolTypeAlias &alias,
12641313
/*requirementID=*/None);
12651314
}
12661315

1267-
/// Record information about a protocol if we have no seen it yet.
1268-
void RuleBuilder::addProtocol(const ProtocolDecl *proto,
1269-
bool initialComponent) {
1270-
if (ProtocolMap.count(proto) > 0)
1271-
return;
1272-
1273-
ProtocolMap[proto] = initialComponent;
1274-
Protocols.push_back(proto);
1316+
/// If we haven't seen this protocol yet, save it for later so that we can
1317+
/// import the rewrite rules from its connected component.
1318+
void RuleBuilder::addReferencedProtocol(const ProtocolDecl *proto) {
1319+
if (ReferencedProtocols.insert(proto).second)
1320+
ProtocolsToImport.push_back(proto);
12751321
}
12761322

12771323
/// Compute the transitive closure of the set of all protocols referenced from
12781324
/// the right hand sides of conformance requirements, and convert their
12791325
/// requirements to rewrite rules.
12801326
void RuleBuilder::collectRulesFromReferencedProtocols() {
1327+
// Compute the transitive closure.
12811328
unsigned i = 0;
1282-
while (i < Protocols.size()) {
1283-
auto *proto = Protocols[i++];
1329+
while (i < ProtocolsToImport.size()) {
1330+
auto *proto = ProtocolsToImport[i++];
12841331
for (auto *depProto : proto->getProtocolDependencies()) {
1285-
addProtocol(depProto, /*initialComponent=*/false);
1332+
addReferencedProtocol(depProto);
12861333
}
12871334
}
12881335

1289-
// Add rewrite rules for each protocol.
1290-
for (auto *proto : Protocols) {
1336+
// If this is a rewrite system for a generic signature, add rewrite rules for
1337+
// each referenced protocol.
1338+
//
1339+
// if this is a rewrite system for a connected component of the protocol
1340+
// dependency graph, add rewrite rules for each referenced protocol not part
1341+
// of this connected component.
1342+
for (auto *proto : ProtocolsToImport) {
12911343
if (Dump) {
12921344
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
12931345
}
12941346

1295-
// Add the identity conformance rule [P].[P] => [P].
1296-
MutableTerm lhs;
1297-
lhs.add(Symbol::forProtocol(proto, Context));
1298-
lhs.add(Symbol::forProtocol(proto, Context));
1299-
1300-
MutableTerm rhs;
1301-
rhs.add(Symbol::forProtocol(proto, Context));
1302-
1303-
PermanentRules.emplace_back(lhs, rhs);
1304-
1305-
for (auto *assocType : proto->getAssociatedTypeMembers())
1306-
addAssociatedType(assocType, proto);
1307-
1308-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
1309-
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
1310-
addAssociatedType(assocType, proto);
1311-
}
1347+
addPermanentProtocolRules(proto);
13121348

1313-
// If this protocol is part of the initial connected component, we're
1314-
// building requirement signatures for all protocols in this component,
1315-
// and so we must start with the structural requirements.
1316-
//
1317-
// Otherwise, we should either already have a requirement signature, or
1318-
// we can trigger the computation of the requirement signatures of the
1319-
// next component recursively.
1320-
if (ProtocolMap[proto]) {
1321-
for (auto req : proto->getStructuralRequirements())
1322-
addRequirement(req, proto);
1323-
1324-
for (auto req : proto->getTypeAliasRequirements())
1325-
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1326-
} else {
1327-
auto reqs = proto->getRequirementSignature();
1328-
for (auto req : reqs.getRequirements())
1329-
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1330-
for (auto alias : reqs.getTypeAliases())
1331-
addTypeAlias(alias, proto);
1332-
}
1349+
auto reqs = proto->getRequirementSignature();
1350+
for (auto req : reqs.getRequirements())
1351+
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1352+
for (auto alias : reqs.getTypeAliases())
1353+
addTypeAlias(alias, proto);
13331354

13341355
if (Dump) {
13351356
llvm::dbgs() << "}\n";

lib/AST/RequirementMachine/RequirementLowering.h

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
#include "swift/AST/Type.h"
1717
#include "llvm/ADT/ArrayRef.h"
18-
#include "llvm/ADT/DenseMap.h"
18+
#include "llvm/ADT/DenseSet.h"
1919
#include "llvm/ADT/SmallVector.h"
2020
#include <vector>
2121
#include "Diagnostics.h"
@@ -76,24 +76,22 @@ getRuleForRequirement(const Requirement &req,
7676
struct RuleBuilder {
7777
RewriteContext &Context;
7878

79-
/// The keys are the unique protocols we've added so far. The value indicates
80-
/// whether the protocol's SCC is an initial component for the rewrite system.
81-
///
82-
/// A rewrite system built from a generic signature does not have any initial
83-
/// protocols.
84-
///
85-
/// A rewrite system built from a protocol SCC has the protocols of the SCC
86-
/// itself as initial protocols.
79+
/// The transitive closure of all protocols appearing on the right hand
80+
/// side of conformance requirements.
81+
llvm::DenseSet<const ProtocolDecl *> &ReferencedProtocols;
82+
83+
/// A subset of the above in insertion order, consisting of the protocols
84+
/// whose rules we are going to import.
8785
///
88-
/// If a protocol is an initial protocol, we use its structural requirements
89-
/// instead of its requirement signature as the basis of its rewrite rules.
86+
/// If this is a rewrite system built from a generic signature, this vector
87+
/// contains all elements in the above set.
9088
///
91-
/// This is what breaks the cycle in requirement signature computation for a
92-
/// group of interdependent protocols.
93-
llvm::DenseMap<const ProtocolDecl *, bool> &ProtocolMap;
94-
95-
/// The keys of the above map in insertion order.
96-
std::vector<const ProtocolDecl *> Protocols;
89+
/// If this is a rewrite system built from a strongly connected component
90+
/// of the protocol, this vector contains all elements in the above set
91+
/// except for the protocols belonging to the component representing the
92+
/// rewrite system itself; those protocols are added directly instead of
93+
/// being imported.
94+
std::vector<const ProtocolDecl *> ProtocolsToImport;
9795

9896
/// New rules to add which will be marked 'permanent'. These are rules for
9997
/// introducing associated types, and relationships between layout,
@@ -116,18 +114,18 @@ struct RuleBuilder {
116114
bool Dump;
117115

118116
RuleBuilder(RewriteContext &ctx,
119-
llvm::DenseMap<const ProtocolDecl *, bool> &protocolMap)
120-
: Context(ctx), ProtocolMap(protocolMap),
117+
llvm::DenseSet<const ProtocolDecl *> &referencedProtocols)
118+
: Context(ctx), ReferencedProtocols(referencedProtocols),
121119
Dump(ctx.getASTContext().LangOpts.DumpRequirementMachine) {}
122120

123121
void addRequirements(ArrayRef<Requirement> requirements);
124122
void addRequirements(ArrayRef<StructuralRequirement> requirements);
125123
void addProtocols(ArrayRef<const ProtocolDecl *> proto);
126-
void addProtocol(const ProtocolDecl *proto,
127-
bool initialComponent);
124+
void addReferencedProtocol(const ProtocolDecl *proto);
128125
void collectRulesFromReferencedProtocols();
129126

130127
private:
128+
void addPermanentProtocolRules(const ProtocolDecl *proto);
131129
void addAssociatedType(const AssociatedTypeDecl *type,
132130
const ProtocolDecl *proto);
133131
void addRequirement(const Requirement &req,

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
9090

9191
// Collect the top-level requirements, and all transtively-referenced
9292
// protocol requirement signatures.
93-
RuleBuilder builder(Context, System.getProtocolMap());
93+
RuleBuilder builder(Context, System.getReferencedProtocols());
9494
builder.addRequirements(sig.getRequirements());
9595

9696
// Add the initial set of rewrite rules to the rewrite system.
@@ -134,7 +134,7 @@ RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
134134
llvm::dbgs() << " {\n";
135135
}
136136

137-
RuleBuilder builder(Context, System.getProtocolMap());
137+
RuleBuilder builder(Context, System.getReferencedProtocols());
138138
builder.addProtocols(protos);
139139

140140
// Add the initial set of rewrite rules to the rewrite system.
@@ -181,7 +181,7 @@ RequirementMachine::initWithWrittenRequirements(
181181

182182
// Collect the top-level requirements, and all transtively-referenced
183183
// protocol requirement signatures.
184-
RuleBuilder builder(Context, System.getProtocolMap());
184+
RuleBuilder builder(Context, System.getReferencedProtocols());
185185
builder.addRequirements(requirements);
186186

187187
// Add the initial set of rewrite rules to the rewrite system.

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,15 @@ class RewriteSystem final {
7676
/// type is an index into the Rules array defined above.
7777
Trie<unsigned, MatchKind::Shortest> Trie;
7878

79-
/// The set of protocols known to this rewrite system. The boolean associated
80-
/// with each key is true if the protocol is part of the 'Protos' set above,
81-
/// otherwies it is false.
79+
/// The set of protocols known to this rewrite system.
8280
///
83-
/// See RuleBuilder::ProtocolMap for a more complete explanation. For the most
84-
/// part, this is only used while building the rewrite system, but conditional
85-
/// requirement inference forces us to be able to add new protocols to the
86-
/// rewrite system after the fact, so this little bit of RuleBuilder state
87-
/// outlives the initialization phase.
88-
llvm::DenseMap<const ProtocolDecl *, bool> ProtocolMap;
81+
/// See RuleBuilder::ReferencedProtocols for a more complete explanation.
82+
///
83+
/// For the most part, this is only used while building the rewrite system,
84+
/// but conditional requirement inference forces us to be able to add new
85+
/// protocols to the rewrite system after the fact, so this little bit of
86+
/// RuleBuilder state outlives the initialization phase.
87+
llvm::DenseSet<const ProtocolDecl *> ReferencedProtocols;
8988

9089
DebugOptions Debug;
9190

@@ -117,8 +116,8 @@ class RewriteSystem final {
117116
/// Return the rewrite context used for allocating memory.
118117
RewriteContext &getRewriteContext() const { return Context; }
119118

120-
llvm::DenseMap<const ProtocolDecl *, bool> &getProtocolMap() {
121-
return ProtocolMap;
119+
llvm::DenseSet<const ProtocolDecl *> &getReferencedProtocols() {
120+
return ReferencedProtocols;
122121
}
123122

124123
DebugOptions getDebugOptions() const { return Debug; }
@@ -133,7 +132,7 @@ class RewriteSystem final {
133132
}
134133

135134
bool isKnownProtocol(const ProtocolDecl *proto) const {
136-
return ProtocolMap.find(proto) != ProtocolMap.end();
135+
return ReferencedProtocols.count(proto) > 0;
137136
}
138137

139138
unsigned getRuleID(const Rule &rule) const {

0 commit comments

Comments
 (0)