Skip to content

Commit 652de97

Browse files
committed
RequirementMachine: Introduce RuleBuilder::initWithConditionalRequirements()
1 parent fb487f8 commit 652de97

File tree

2 files changed

+55
-8
lines changed

2 files changed

+55
-8
lines changed

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,7 @@ void RuleBuilder::initWithGenericSignatureRequirements(
10671067

10681068
// Add rewrite rules for all top-level requirements.
10691069
for (const auto &req : requirements)
1070-
addRequirement(req, /*proto=*/nullptr, /*requirementID=*/None);
1070+
addRequirement(req, /*proto=*/nullptr);
10711071
}
10721072

10731073
/// For building a rewrite system for a generic signature from user-written
@@ -1119,7 +1119,7 @@ void RuleBuilder::initWithProtocolSignatureRequirements(
11191119

11201120
auto reqs = proto->getRequirementSignature();
11211121
for (auto req : reqs.getRequirements())
1122-
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1122+
addRequirement(req.getCanonical(), proto);
11231123
for (auto alias : reqs.getTypeAliases())
11241124
addTypeAlias(alias, proto);
11251125

@@ -1160,9 +1160,8 @@ void RuleBuilder::initWithProtocolWrittenRequirements(
11601160

11611161
for (auto req : proto->getStructuralRequirements())
11621162
addRequirement(req, proto);
1163-
11641163
for (auto req : proto->getTypeAliasRequirements())
1165-
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
1164+
addRequirement(req.getCanonical(), proto);
11661165

11671166
for (auto *otherProto : proto->getProtocolDependencies())
11681167
addReferencedProtocol(otherProto);
@@ -1177,6 +1176,43 @@ void RuleBuilder::initWithProtocolWrittenRequirements(
11771176
collectRulesFromReferencedProtocols();
11781177
}
11791178

1179+
/// For adding conditional conformance requirements to an existing rewrite
1180+
/// system. This might pull in additional protocols that we haven't seen
1181+
/// before.
1182+
///
1183+
/// The interface types in the requirements are converted to terms relative
1184+
/// to the given array of substitutions, using
1185+
/// RewriteContext::getRelativeTermForType().
1186+
///
1187+
/// For example, given a concrete conformance rule:
1188+
///
1189+
/// X.Y.[concrete: Array<X.Z> : Equatable]
1190+
///
1191+
/// The substitutions are {τ_0_0 := X.Z}, and the Array : Equatable conformance
1192+
/// has a conditional requirement 'τ_0_0 : Equatable', so the following
1193+
/// conformance rule will be added:
1194+
///
1195+
/// X.Z.[Equatable] => X.Z
1196+
void RuleBuilder::initWithConditionalRequirements(
1197+
ArrayRef<Requirement> requirements,
1198+
ArrayRef<Term> substitutions) {
1199+
assert(!Initialized);
1200+
Initialized = 1;
1201+
1202+
// Collect all protocols transitively referenced from these requirements.
1203+
for (auto req : requirements) {
1204+
if (req.getKind() == RequirementKind::Conformance) {
1205+
addReferencedProtocol(req.getProtocolDecl());
1206+
}
1207+
}
1208+
1209+
collectRulesFromReferencedProtocols();
1210+
1211+
// Add rewrite rules for all top-level requirements.
1212+
for (const auto &req : requirements)
1213+
addRequirement(req.getCanonical(), /*proto=*/nullptr, substitutions);
1214+
}
1215+
11801216
/// Add permanent rules for a protocol, consisting of:
11811217
///
11821218
/// - The identity conformance rule [P].[P] => [P].
@@ -1323,8 +1359,16 @@ swift::rewriting::getRuleForRequirement(const Requirement &req,
13231359
return std::make_pair(subjectTerm, constraintTerm);
13241360
}
13251361

1362+
/// Convert a requirement to a rule and add it to the builder.
1363+
///
1364+
/// The types in the requirement must be canonical.
1365+
///
1366+
/// If \p substitutions is not None, the interface types in the requirement
1367+
/// are converted to terms relative to these substitutions, using
1368+
/// RewriteContext::getRelativeTermForType().
13261369
void RuleBuilder::addRequirement(const Requirement &req,
13271370
const ProtocolDecl *proto,
1371+
Optional<ArrayRef<Term>> substitutions,
13281372
Optional<unsigned> requirementID) {
13291373
if (Dump) {
13301374
llvm::dbgs() << "+ ";
@@ -1333,8 +1377,7 @@ void RuleBuilder::addRequirement(const Requirement &req,
13331377
}
13341378

13351379
auto rule =
1336-
getRuleForRequirement(req, proto, /*substitutions=*/None,
1337-
Context);
1380+
getRuleForRequirement(req, proto, substitutions, Context);
13381381
RequirementRules.push_back(
13391382
std::make_tuple(rule.first, rule.second, requirementID));
13401383
}
@@ -1343,7 +1386,8 @@ void RuleBuilder::addRequirement(const StructuralRequirement &req,
13431386
const ProtocolDecl *proto) {
13441387
WrittenRequirements.push_back(req);
13451388
unsigned requirementID = WrittenRequirements.size() - 1;
1346-
addRequirement(req.req.getCanonical(), proto, requirementID);
1389+
addRequirement(req.req.getCanonical(), proto, /*substitutions=*/None,
1390+
requirementID);
13471391
}
13481392

13491393
/// Lowers a protocol typealias to a rewrite rule.

lib/AST/RequirementMachine/RequirementLowering.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ struct RuleBuilder {
132132
void initWithWrittenRequirements(ArrayRef<StructuralRequirement> requirements);
133133
void initWithProtocolSignatureRequirements(ArrayRef<const ProtocolDecl *> proto);
134134
void initWithProtocolWrittenRequirements(ArrayRef<const ProtocolDecl *> proto);
135+
void initWithConditionalRequirements(ArrayRef<Requirement> requirements,
136+
ArrayRef<Term> substitutions);
135137
void addReferencedProtocol(const ProtocolDecl *proto);
136138
void collectRulesFromReferencedProtocols();
137139

@@ -141,7 +143,8 @@ struct RuleBuilder {
141143
const ProtocolDecl *proto);
142144
void addRequirement(const Requirement &req,
143145
const ProtocolDecl *proto,
144-
Optional<unsigned> requirementID);
146+
Optional<ArrayRef<Term>> substitutions=None,
147+
Optional<unsigned> requirementID=None);
145148
void addRequirement(const StructuralRequirement &req,
146149
const ProtocolDecl *proto);
147150
void addTypeAlias(const ProtocolTypeAlias &alias,

0 commit comments

Comments
 (0)