Skip to content

Commit 8c27759

Browse files
authored
Merge pull request #41817 from slavapestov/rqm-rule-sharing
RequirementMachine: Rule sharing
2 parents fcde683 + d7f8445 commit 8c27759

19 files changed

+507
-162
lines changed

lib/AST/RequirementMachine/ConcreteTypeWitness.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,8 @@ void PropertyMap::inferConditionalRequirements(
562562
llvm::dbgs() << "@@@ Unknown protocol: "<< proto->getName() << "\n";
563563
}
564564

565-
RuleBuilder builder(Context, System.getProtocolMap());
566-
builder.addProtocol(proto, /*initialComponent=*/false);
565+
RuleBuilder builder(Context, System.getReferencedProtocols());
566+
builder.addReferencedProtocol(proto);
567567
builder.collectRulesFromReferencedProtocols();
568568

569569
for (const auto &rule : builder.PermanentRules)

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,8 @@ RewriteSystem::getMinimizedProtocolRules() const {
682682
assert(!Protos.empty());
683683

684684
llvm::DenseMap<const ProtocolDecl *, MinimizedProtocolRules> rules;
685-
for (unsigned ruleID : indices(Rules)) {
685+
for (unsigned ruleID = FirstLocalRule, e = Rules.size();
686+
ruleID < e; ++ruleID) {
686687
const auto &rule = getRule(ruleID);
687688

688689
if (rule.isPermanent() ||
@@ -713,7 +714,8 @@ RewriteSystem::getMinimizedGenericSignatureRules() const {
713714
assert(Protos.empty());
714715

715716
std::vector<unsigned> rules;
716-
for (unsigned ruleID : indices(Rules)) {
717+
for (unsigned ruleID = FirstLocalRule, e = Rules.size();
718+
ruleID < e; ++ruleID) {
717719
const auto &rule = getRule(ruleID);
718720

719721
if (rule.isPermanent() ||
@@ -768,7 +770,8 @@ void RewriteSystem::verifyMinimizedRules(
768770
const llvm::DenseSet<unsigned> &redundantConformances) const {
769771
unsigned redundantRuleCount = 0;
770772

771-
for (unsigned ruleID : indices(Rules)) {
773+
for (unsigned ruleID = FirstLocalRule, e = Rules.size();
774+
ruleID < e; ++ruleID) {
772775
const auto &rule = getRule(ruleID);
773776

774777
// Ignore the rewrite rule if it is not part of our minimization domain.

lib/AST/RequirementMachine/KnuthBendix.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount,
301301
ruleCount = Rules.size();
302302

303303
// For every rule, looking for other rules that overlap with this rule.
304-
for (unsigned i = 0, e = Rules.size(); i < e; ++i) {
304+
for (unsigned i = FirstLocalRule, e = Rules.size(); i < e; ++i) {
305305
const auto &lhs = getRule(i);
306306
if (lhs.isLHSSimplified() ||
307307
lhs.isRHSSimplified() ||

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 162 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
#include "swift/AST/TypeMatcher.h"
3434
#include "swift/AST/TypeRepr.h"
3535
#include "llvm/ADT/SmallVector.h"
36+
#include "llvm/ADT/SetVector.h"
37+
#include "RequirementMachine.h"
3638
#include "RewriteContext.h"
3739
#include "RewriteSystem.h"
3840
#include "Symbol.h"
@@ -1013,7 +1015,7 @@ ArrayRef<ProtocolDecl *>
10131015
ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10141016
ProtocolDecl *proto) const {
10151017
auto &ctx = proto->getASTContext();
1016-
SmallVector<ProtocolDecl *, 4> result;
1018+
SmallSetVector<ProtocolDecl *, 4> result;
10171019

10181020
// If we have a serialized requirement signature, deserialize it and
10191021
// look at conformance requirements.
@@ -1025,7 +1027,7 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10251027
== RequirementMachineMode::Disabled)) {
10261028
for (auto req : proto->getRequirementSignature().getRequirements()) {
10271029
if (req.getKind() == RequirementKind::Conformance) {
1028-
result.push_back(req.getProtocolDecl());
1030+
result.insert(req.getProtocolDecl());
10291031
}
10301032
}
10311033

@@ -1037,7 +1039,7 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10371039
// signature. Look at the structural requirements instead.
10381040
for (auto req : proto->getStructuralRequirements()) {
10391041
if (req.req.getKind() == RequirementKind::Conformance)
1040-
result.push_back(req.req.getProtocolDecl());
1042+
result.insert(req.req.getProtocolDecl());
10411043
}
10421044

10431045
return ctx.AllocateCopy(result);
@@ -1047,11 +1049,17 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
10471049
// Building rewrite rules from desugared requirements.
10481050
//
10491051

1050-
void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {
1052+
/// For building a rewrite system for a generic signature from canonical
1053+
/// requirements.
1054+
void RuleBuilder::initWithGenericSignatureRequirements(
1055+
ArrayRef<Requirement> requirements) {
1056+
assert(!Initialized);
1057+
Initialized = 1;
1058+
10511059
// Collect all protocols transitively referenced from these requirements.
10521060
for (auto req : requirements) {
10531061
if (req.getKind() == RequirementKind::Conformance) {
1054-
addProtocol(req.getProtocolDecl(), /*initialComponent=*/false);
1062+
addReferencedProtocol(req.getProtocolDecl());
10551063
}
10561064
}
10571065

@@ -1062,11 +1070,17 @@ void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {
10621070
addRequirement(req, /*proto=*/nullptr, /*requirementID=*/None);
10631071
}
10641072

1065-
void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements) {
1073+
/// For building a rewrite system for a generic signature from user-written
1074+
/// requirements.
1075+
void RuleBuilder::initWithWrittenRequirements(
1076+
ArrayRef<StructuralRequirement> requirements) {
1077+
assert(!Initialized);
1078+
Initialized = 1;
1079+
10661080
// Collect all protocols transitively referenced from these requirements.
10671081
for (auto req : requirements) {
10681082
if (req.req.getKind() == RequirementKind::Conformance) {
1069-
addProtocol(req.req.getProtocolDecl(), /*initialComponent=*/false);
1083+
addReferencedProtocol(req.req.getProtocolDecl());
10701084
}
10711085
}
10721086

@@ -1077,16 +1091,117 @@ void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements)
10771091
addRequirement(req, /*proto=*/nullptr);
10781092
}
10791093

1080-
void RuleBuilder::addProtocols(ArrayRef<const ProtocolDecl *> protos) {
1094+
/// For building a rewrite system for a protocol connected component from
1095+
/// a previously-built requirement signature.
1096+
///
1097+
/// Will trigger requirement signature computation if we haven't built
1098+
/// requirement signatures for this connected component yet, in which case we
1099+
/// will recursively end up building another rewrite system for this component
1100+
/// using initWithProtocolWrittenRequirements().
1101+
void RuleBuilder::initWithProtocolSignatureRequirements(
1102+
ArrayRef<const ProtocolDecl *> protos) {
1103+
assert(!Initialized);
1104+
Initialized = 1;
1105+
1106+
// Add all protocols to the referenced set, so that subsequent calls
1107+
// to addReferencedProtocol() with one of these protocols don't add
1108+
// them to the import list.
1109+
for (auto *proto : protos) {
1110+
ReferencedProtocols.insert(proto);
1111+
}
1112+
1113+
for (auto *proto : protos) {
1114+
if (Dump) {
1115+
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
1116+
}
1117+
1118+
addPermanentProtocolRules(proto);
1119+
1120+
auto reqs = proto->getRequirementSignature();
1121+
for (auto req : reqs.getRequirements())
1122+
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1123+
for (auto alias : reqs.getTypeAliases())
1124+
addTypeAlias(alias, proto);
1125+
1126+
for (auto *otherProto : proto->getProtocolDependencies())
1127+
addReferencedProtocol(otherProto);
1128+
1129+
if (Dump) {
1130+
llvm::dbgs() << "}\n";
1131+
}
1132+
}
1133+
10811134
// Collect all protocols transitively referenced from this connected component
10821135
// of the protocol dependency graph.
1083-
for (auto proto : protos) {
1084-
addProtocol(proto, /*initialComponent=*/true);
1136+
collectRulesFromReferencedProtocols();
1137+
}
1138+
1139+
/// For building a rewrite system for a protocol connected component from
1140+
/// user-written requirements. Used when actually building requirement
1141+
/// signatures.
1142+
void RuleBuilder::initWithProtocolWrittenRequirements(
1143+
ArrayRef<const ProtocolDecl *> protos) {
1144+
assert(!Initialized);
1145+
Initialized = 1;
1146+
1147+
// Add all protocols to the referenced set, so that subsequent calls
1148+
// to addReferencedProtocol() with one of these protocols don't add
1149+
// them to the import list.
1150+
for (auto *proto : protos) {
1151+
ReferencedProtocols.insert(proto);
1152+
}
1153+
1154+
for (auto *proto : protos) {
1155+
if (Dump) {
1156+
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
1157+
}
1158+
1159+
addPermanentProtocolRules(proto);
1160+
1161+
for (auto req : proto->getStructuralRequirements())
1162+
addRequirement(req, proto);
1163+
1164+
for (auto req : proto->getTypeAliasRequirements())
1165+
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1166+
1167+
for (auto *otherProto : proto->getProtocolDependencies())
1168+
addReferencedProtocol(otherProto);
1169+
1170+
if (Dump) {
1171+
llvm::dbgs() << "}\n";
1172+
}
10851173
}
10861174

1175+
// Collect all protocols transitively referenced from this connected component
1176+
// of the protocol dependency graph.
10871177
collectRulesFromReferencedProtocols();
10881178
}
10891179

1180+
/// Add permanent rules for a protocol, consisting of:
1181+
///
1182+
/// - The identity conformance rule [P].[P] => [P].
1183+
/// - An associated type introduction rule for each associated type.
1184+
/// - An inherited associated type introduction rule for each associated
1185+
/// type of each inherited protocol.
1186+
void RuleBuilder::addPermanentProtocolRules(const ProtocolDecl *proto) {
1187+
MutableTerm lhs;
1188+
lhs.add(Symbol::forProtocol(proto, Context));
1189+
lhs.add(Symbol::forProtocol(proto, Context));
1190+
1191+
MutableTerm rhs;
1192+
rhs.add(Symbol::forProtocol(proto, Context));
1193+
1194+
PermanentRules.emplace_back(lhs, rhs);
1195+
1196+
for (auto *assocType : proto->getAssociatedTypeMembers())
1197+
addAssociatedType(assocType, proto);
1198+
1199+
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
1200+
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
1201+
addAssociatedType(assocType, proto);
1202+
}
1203+
}
1204+
10901205
/// For an associated type T in a protocol P, we add a rewrite rule:
10911206
///
10921207
/// [P].T => [P:T]
@@ -1264,75 +1379,58 @@ void RuleBuilder::addTypeAlias(const ProtocolTypeAlias &alias,
12641379
/*requirementID=*/None);
12651380
}
12661381

1267-
/// Record information about a protocol if we have no seen it yet.
1268-
void RuleBuilder::addProtocol(const ProtocolDecl *proto,
1269-
bool initialComponent) {
1270-
if (ProtocolMap.count(proto) > 0)
1271-
return;
1272-
1273-
ProtocolMap[proto] = initialComponent;
1274-
Protocols.push_back(proto);
1382+
/// If we haven't seen this protocol yet, save it for later so that we can
1383+
/// import the rewrite rules from its connected component.
1384+
void RuleBuilder::addReferencedProtocol(const ProtocolDecl *proto) {
1385+
if (ReferencedProtocols.insert(proto).second)
1386+
ProtocolsToImport.push_back(proto);
12751387
}
12761388

12771389
/// Compute the transitive closure of the set of all protocols referenced from
12781390
/// the right hand sides of conformance requirements, and convert their
12791391
/// requirements to rewrite rules.
12801392
void RuleBuilder::collectRulesFromReferencedProtocols() {
1393+
// Compute the transitive closure.
12811394
unsigned i = 0;
1282-
while (i < Protocols.size()) {
1283-
auto *proto = Protocols[i++];
1395+
while (i < ProtocolsToImport.size()) {
1396+
auto *proto = ProtocolsToImport[i++];
12841397
for (auto *depProto : proto->getProtocolDependencies()) {
1285-
addProtocol(depProto, /*initialComponent=*/false);
1398+
addReferencedProtocol(depProto);
12861399
}
12871400
}
12881401

1289-
// Add rewrite rules for each protocol.
1290-
for (auto *proto : Protocols) {
1402+
// If this is a rewrite system for a generic signature, add rewrite rules for
1403+
// each referenced protocol.
1404+
//
1405+
// if this is a rewrite system for a connected component of the protocol
1406+
// dependency graph, add rewrite rules for each referenced protocol not part
1407+
// of this connected component.
1408+
1409+
// First, collect all unique requirement machines, one for each connected
1410+
// component of each referenced protocol.
1411+
llvm::DenseSet<RequirementMachine *> machines;
1412+
1413+
// Now visit each subordinate requirement machine pull in its rules.
1414+
for (auto *proto : ProtocolsToImport) {
1415+
// This will trigger requirement signature computation for this protocol,
1416+
// if neccessary, which will cause us to re-enter into a new RuleBuilder
1417+
// instace under RuleBuilder::initWithProtocolWrittenRequirements().
12911418
if (Dump) {
1292-
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
1419+
llvm::dbgs() << "importing protocol " << proto->getName() << " {\n";
12931420
}
12941421

1295-
// Add the identity conformance rule [P].[P] => [P].
1296-
MutableTerm lhs;
1297-
lhs.add(Symbol::forProtocol(proto, Context));
1298-
lhs.add(Symbol::forProtocol(proto, Context));
1299-
1300-
MutableTerm rhs;
1301-
rhs.add(Symbol::forProtocol(proto, Context));
1302-
1303-
PermanentRules.emplace_back(lhs, rhs);
1304-
1305-
for (auto *assocType : proto->getAssociatedTypeMembers())
1306-
addAssociatedType(assocType, proto);
1307-
1308-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
1309-
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
1310-
addAssociatedType(assocType, proto);
1311-
}
1312-
1313-
// If this protocol is part of the initial connected component, we're
1314-
// building requirement signatures for all protocols in this component,
1315-
// and so we must start with the structural requirements.
1316-
//
1317-
// Otherwise, we should either already have a requirement signature, or
1318-
// we can trigger the computation of the requirement signatures of the
1319-
// next component recursively.
1320-
if (ProtocolMap[proto]) {
1321-
for (auto req : proto->getStructuralRequirements())
1322-
addRequirement(req, proto);
1323-
1324-
for (auto req : proto->getTypeAliasRequirements())
1325-
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1326-
} else {
1327-
auto reqs = proto->getRequirementSignature();
1328-
for (auto req : reqs.getRequirements())
1329-
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1330-
for (auto alias : reqs.getTypeAliases())
1331-
addTypeAlias(alias, proto);
1422+
auto *machine = Context.getRequirementMachine(proto);
1423+
if (!machines.insert(machine).second) {
1424+
// We've already seen this connected component.
1425+
continue;
13321426
}
13331427

1334-
if (Dump) {
1335-
llvm::dbgs() << "}\n";
1336-
}
1428+
// We grab the machine's local rules, not *all* of its rules, to avoid
1429+
// duplicates in case multiple machines share a dependency on a downstream
1430+
// protocol connected component.
1431+
auto localRules = machine->getLocalRules();
1432+
ImportedRules.insert(ImportedRules.end(),
1433+
localRules.begin(),
1434+
localRules.end());
13371435
}
13381436
}

0 commit comments

Comments
 (0)