33
33
#include " swift/AST/TypeMatcher.h"
34
34
#include " swift/AST/TypeRepr.h"
35
35
#include " llvm/ADT/SmallVector.h"
36
+ #include " llvm/ADT/SetVector.h"
37
+ #include " RequirementMachine.h"
36
38
#include " RewriteContext.h"
37
39
#include " RewriteSystem.h"
38
40
#include " Symbol.h"
@@ -1013,7 +1015,7 @@ ArrayRef<ProtocolDecl *>
1013
1015
ProtocolDependenciesRequest::evaluate (Evaluator &evaluator,
1014
1016
ProtocolDecl *proto) const {
1015
1017
auto &ctx = proto->getASTContext ();
1016
- SmallVector <ProtocolDecl *, 4 > result;
1018
+ SmallSetVector <ProtocolDecl *, 4 > result;
1017
1019
1018
1020
// If we have a serialized requirement signature, deserialize it and
1019
1021
// look at conformance requirements.
@@ -1025,7 +1027,7 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
1025
1027
== RequirementMachineMode::Disabled)) {
1026
1028
for (auto req : proto->getRequirementSignature ().getRequirements ()) {
1027
1029
if (req.getKind () == RequirementKind::Conformance) {
1028
- result.push_back (req.getProtocolDecl ());
1030
+ result.insert (req.getProtocolDecl ());
1029
1031
}
1030
1032
}
1031
1033
@@ -1037,7 +1039,7 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
1037
1039
// signature. Look at the structural requirements instead.
1038
1040
for (auto req : proto->getStructuralRequirements ()) {
1039
1041
if (req.req .getKind () == RequirementKind::Conformance)
1040
- result.push_back (req.req .getProtocolDecl ());
1042
+ result.insert (req.req .getProtocolDecl ());
1041
1043
}
1042
1044
1043
1045
return ctx.AllocateCopy (result);
@@ -1047,11 +1049,17 @@ ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
1047
1049
// Building rewrite rules from desugared requirements.
1048
1050
//
1049
1051
1050
- void RuleBuilder::addRequirements (ArrayRef<Requirement> requirements) {
1052
+ // / For building a rewrite system for a generic signature from canonical
1053
+ // / requirements.
1054
+ void RuleBuilder::initWithGenericSignatureRequirements (
1055
+ ArrayRef<Requirement> requirements) {
1056
+ assert (!Initialized);
1057
+ Initialized = 1 ;
1058
+
1051
1059
// Collect all protocols transitively referenced from these requirements.
1052
1060
for (auto req : requirements) {
1053
1061
if (req.getKind () == RequirementKind::Conformance) {
1054
- addProtocol (req.getProtocolDecl (), /* initialComponent= */ false );
1062
+ addReferencedProtocol (req.getProtocolDecl ());
1055
1063
}
1056
1064
}
1057
1065
@@ -1062,11 +1070,17 @@ void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {
1062
1070
addRequirement (req, /* proto=*/ nullptr , /* requirementID=*/ None);
1063
1071
}
1064
1072
1065
- void RuleBuilder::addRequirements (ArrayRef<StructuralRequirement> requirements) {
1073
+ // / For building a rewrite system for a generic signature from user-written
1074
+ // / requirements.
1075
+ void RuleBuilder::initWithWrittenRequirements (
1076
+ ArrayRef<StructuralRequirement> requirements) {
1077
+ assert (!Initialized);
1078
+ Initialized = 1 ;
1079
+
1066
1080
// Collect all protocols transitively referenced from these requirements.
1067
1081
for (auto req : requirements) {
1068
1082
if (req.req .getKind () == RequirementKind::Conformance) {
1069
- addProtocol (req.req .getProtocolDecl (), /* initialComponent= */ false );
1083
+ addReferencedProtocol (req.req .getProtocolDecl ());
1070
1084
}
1071
1085
}
1072
1086
@@ -1077,16 +1091,117 @@ void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements)
1077
1091
addRequirement (req, /* proto=*/ nullptr );
1078
1092
}
1079
1093
1080
- void RuleBuilder::addProtocols (ArrayRef<const ProtocolDecl *> protos) {
1094
+ // / For building a rewrite system for a protocol connected component from
1095
+ // / a previously-built requirement signature.
1096
+ // /
1097
+ // / Will trigger requirement signature computation if we haven't built
1098
+ // / requirement signatures for this connected component yet, in which case we
1099
+ // / will recursively end up building another rewrite system for this component
1100
+ // / using initWithProtocolWrittenRequirements().
1101
+ void RuleBuilder::initWithProtocolSignatureRequirements (
1102
+ ArrayRef<const ProtocolDecl *> protos) {
1103
+ assert (!Initialized);
1104
+ Initialized = 1 ;
1105
+
1106
+ // Add all protocols to the referenced set, so that subsequent calls
1107
+ // to addReferencedProtocol() with one of these protocols don't add
1108
+ // them to the import list.
1109
+ for (auto *proto : protos) {
1110
+ ReferencedProtocols.insert (proto);
1111
+ }
1112
+
1113
+ for (auto *proto : protos) {
1114
+ if (Dump) {
1115
+ llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
1116
+ }
1117
+
1118
+ addPermanentProtocolRules (proto);
1119
+
1120
+ auto reqs = proto->getRequirementSignature ();
1121
+ for (auto req : reqs.getRequirements ())
1122
+ addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1123
+ for (auto alias : reqs.getTypeAliases ())
1124
+ addTypeAlias (alias, proto);
1125
+
1126
+ for (auto *otherProto : proto->getProtocolDependencies ())
1127
+ addReferencedProtocol (otherProto);
1128
+
1129
+ if (Dump) {
1130
+ llvm::dbgs () << " }\n " ;
1131
+ }
1132
+ }
1133
+
1081
1134
// Collect all protocols transitively referenced from this connected component
1082
1135
// of the protocol dependency graph.
1083
- for (auto proto : protos) {
1084
- addProtocol (proto, /* initialComponent=*/ true );
1136
+ collectRulesFromReferencedProtocols ();
1137
+ }
1138
+
1139
+ // / For building a rewrite system for a protocol connected component from
1140
+ // / user-written requirements. Used when actually building requirement
1141
+ // / signatures.
1142
+ void RuleBuilder::initWithProtocolWrittenRequirements (
1143
+ ArrayRef<const ProtocolDecl *> protos) {
1144
+ assert (!Initialized);
1145
+ Initialized = 1 ;
1146
+
1147
+ // Add all protocols to the referenced set, so that subsequent calls
1148
+ // to addReferencedProtocol() with one of these protocols don't add
1149
+ // them to the import list.
1150
+ for (auto *proto : protos) {
1151
+ ReferencedProtocols.insert (proto);
1152
+ }
1153
+
1154
+ for (auto *proto : protos) {
1155
+ if (Dump) {
1156
+ llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
1157
+ }
1158
+
1159
+ addPermanentProtocolRules (proto);
1160
+
1161
+ for (auto req : proto->getStructuralRequirements ())
1162
+ addRequirement (req, proto);
1163
+
1164
+ for (auto req : proto->getTypeAliasRequirements ())
1165
+ addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1166
+
1167
+ for (auto *otherProto : proto->getProtocolDependencies ())
1168
+ addReferencedProtocol (otherProto);
1169
+
1170
+ if (Dump) {
1171
+ llvm::dbgs () << " }\n " ;
1172
+ }
1085
1173
}
1086
1174
1175
+ // Collect all protocols transitively referenced from this connected component
1176
+ // of the protocol dependency graph.
1087
1177
collectRulesFromReferencedProtocols ();
1088
1178
}
1089
1179
1180
+ // / Add permanent rules for a protocol, consisting of:
1181
+ // /
1182
+ // / - The identity conformance rule [P].[P] => [P].
1183
+ // / - An associated type introduction rule for each associated type.
1184
+ // / - An inherited associated type introduction rule for each associated
1185
+ // / type of each inherited protocol.
1186
+ void RuleBuilder::addPermanentProtocolRules (const ProtocolDecl *proto) {
1187
+ MutableTerm lhs;
1188
+ lhs.add (Symbol::forProtocol (proto, Context));
1189
+ lhs.add (Symbol::forProtocol (proto, Context));
1190
+
1191
+ MutableTerm rhs;
1192
+ rhs.add (Symbol::forProtocol (proto, Context));
1193
+
1194
+ PermanentRules.emplace_back (lhs, rhs);
1195
+
1196
+ for (auto *assocType : proto->getAssociatedTypeMembers ())
1197
+ addAssociatedType (assocType, proto);
1198
+
1199
+ for (auto *inheritedProto : Context.getInheritedProtocols (proto)) {
1200
+ for (auto *assocType : inheritedProto->getAssociatedTypeMembers ())
1201
+ addAssociatedType (assocType, proto);
1202
+ }
1203
+ }
1204
+
1090
1205
// / For an associated type T in a protocol P, we add a rewrite rule:
1091
1206
// /
1092
1207
// / [P].T => [P:T]
@@ -1264,75 +1379,58 @@ void RuleBuilder::addTypeAlias(const ProtocolTypeAlias &alias,
1264
1379
/* requirementID=*/ None);
1265
1380
}
1266
1381
1267
- // / Record information about a protocol if we have no seen it yet.
1268
- void RuleBuilder::addProtocol (const ProtocolDecl *proto,
1269
- bool initialComponent) {
1270
- if (ProtocolMap.count (proto) > 0 )
1271
- return ;
1272
-
1273
- ProtocolMap[proto] = initialComponent;
1274
- Protocols.push_back (proto);
1382
+ // / If we haven't seen this protocol yet, save it for later so that we can
1383
+ // / import the rewrite rules from its connected component.
1384
+ void RuleBuilder::addReferencedProtocol (const ProtocolDecl *proto) {
1385
+ if (ReferencedProtocols.insert (proto).second )
1386
+ ProtocolsToImport.push_back (proto);
1275
1387
}
1276
1388
1277
1389
// / Compute the transitive closure of the set of all protocols referenced from
1278
1390
// / the right hand sides of conformance requirements, and convert their
1279
1391
// / requirements to rewrite rules.
1280
1392
void RuleBuilder::collectRulesFromReferencedProtocols () {
1393
+ // Compute the transitive closure.
1281
1394
unsigned i = 0 ;
1282
- while (i < Protocols .size ()) {
1283
- auto *proto = Protocols [i++];
1395
+ while (i < ProtocolsToImport .size ()) {
1396
+ auto *proto = ProtocolsToImport [i++];
1284
1397
for (auto *depProto : proto->getProtocolDependencies ()) {
1285
- addProtocol (depProto, /* initialComponent= */ false );
1398
+ addReferencedProtocol (depProto);
1286
1399
}
1287
1400
}
1288
1401
1289
- // Add rewrite rules for each protocol.
1290
- for (auto *proto : Protocols) {
1402
+ // If this is a rewrite system for a generic signature, add rewrite rules for
1403
+ // each referenced protocol.
1404
+ //
1405
+ // if this is a rewrite system for a connected component of the protocol
1406
+ // dependency graph, add rewrite rules for each referenced protocol not part
1407
+ // of this connected component.
1408
+
1409
+ // First, collect all unique requirement machines, one for each connected
1410
+ // component of each referenced protocol.
1411
+ llvm::DenseSet<RequirementMachine *> machines;
1412
+
1413
+ // Now visit each subordinate requirement machine pull in its rules.
1414
+ for (auto *proto : ProtocolsToImport) {
1415
+ // This will trigger requirement signature computation for this protocol,
1416
+ // if neccessary, which will cause us to re-enter into a new RuleBuilder
1417
+ // instace under RuleBuilder::initWithProtocolWrittenRequirements().
1291
1418
if (Dump) {
1292
- llvm::dbgs () << " protocol " << proto->getName () << " {\n " ;
1419
+ llvm::dbgs () << " importing protocol " << proto->getName () << " {\n " ;
1293
1420
}
1294
1421
1295
- // Add the identity conformance rule [P].[P] => [P].
1296
- MutableTerm lhs;
1297
- lhs.add (Symbol::forProtocol (proto, Context));
1298
- lhs.add (Symbol::forProtocol (proto, Context));
1299
-
1300
- MutableTerm rhs;
1301
- rhs.add (Symbol::forProtocol (proto, Context));
1302
-
1303
- PermanentRules.emplace_back (lhs, rhs);
1304
-
1305
- for (auto *assocType : proto->getAssociatedTypeMembers ())
1306
- addAssociatedType (assocType, proto);
1307
-
1308
- for (auto *inheritedProto : Context.getInheritedProtocols (proto)) {
1309
- for (auto *assocType : inheritedProto->getAssociatedTypeMembers ())
1310
- addAssociatedType (assocType, proto);
1311
- }
1312
-
1313
- // If this protocol is part of the initial connected component, we're
1314
- // building requirement signatures for all protocols in this component,
1315
- // and so we must start with the structural requirements.
1316
- //
1317
- // Otherwise, we should either already have a requirement signature, or
1318
- // we can trigger the computation of the requirement signatures of the
1319
- // next component recursively.
1320
- if (ProtocolMap[proto]) {
1321
- for (auto req : proto->getStructuralRequirements ())
1322
- addRequirement (req, proto);
1323
-
1324
- for (auto req : proto->getTypeAliasRequirements ())
1325
- addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1326
- } else {
1327
- auto reqs = proto->getRequirementSignature ();
1328
- for (auto req : reqs.getRequirements ())
1329
- addRequirement (req.getCanonical (), proto, /* requirementID=*/ None);
1330
- for (auto alias : reqs.getTypeAliases ())
1331
- addTypeAlias (alias, proto);
1422
+ auto *machine = Context.getRequirementMachine (proto);
1423
+ if (!machines.insert (machine).second ) {
1424
+ // We've already seen this connected component.
1425
+ continue ;
1332
1426
}
1333
1427
1334
- if (Dump) {
1335
- llvm::dbgs () << " }\n " ;
1336
- }
1428
+ // We grab the machine's local rules, not *all* of its rules, to avoid
1429
+ // duplicates in case multiple machines share a dependency on a downstream
1430
+ // protocol connected component.
1431
+ auto localRules = machine->getLocalRules ();
1432
+ ImportedRules.insert (ImportedRules.end (),
1433
+ localRules.begin (),
1434
+ localRules.end ());
1337
1435
}
1338
1436
}
0 commit comments