Skip to content

Commit 952dafa

Browse files
committed
RequirementMachine: Preliminary refactoring in preparation for computing top-level generic signatures
1 parent 796123e commit 952dafa

File tree

7 files changed

+111
-33
lines changed

7 files changed

+111
-33
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,9 @@ class ProtocolDependenciesRequest :
406406
bool isCached() const { return true; }
407407
};
408408

409-
/// Compute the requirements that describe a protocol using the
410-
/// RequirementMachine.
409+
/// Compute a protocol's requirement signature using the RequirementMachine.
410+
/// This is temporary; once the GenericSignatureBuilder goes away this will
411+
/// be folded into RequirementSignatureRequest.
411412
class RequirementSignatureRequestRQM :
412413
public SimpleRequest<RequirementSignatureRequestRQM,
413414
ArrayRef<Requirement>(ProtocolDecl *),

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ SWIFT_REQUEST(TypeChecker, AbstractGenericSignatureRequest,
2020
SmallVector<GenericTypeParamType *, 2>,
2121
SmallVector<Requirement, 2>),
2222
Cached, NoLocationInfo)
23+
SWIFT_REQUEST(TypeChecker, AbstractGenericSignatureRequestRQM,
24+
GenericSignatureWithError (const GenericSignatureImpl *,
25+
SmallVector<GenericTypeParamType *, 2>,
26+
SmallVector<Requirement, 2>),
27+
Cached, NoLocationInfo)
2328
SWIFT_REQUEST(TypeChecker, ApplyAccessNoteRequest,
2429
evaluator::SideEffect(ValueDecl *), Cached, NoLocationInfo)
2530
SWIFT_REQUEST(TypeChecker, AttachedResultBuilderRequest,

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -779,10 +779,14 @@ void RewriteSystem::minimizeRewriteSystem() {
779779
}
780780

781781
/// Collect all non-permanent, non-redundant rules whose domain is equal to
782-
/// one of the protocols in \p proto. These rules form the requirement
783-
/// signatures of these protocols.
782+
/// one of the protocols in \p proto. In other words, the first symbol of the
783+
/// left hand side term is either a protocol symbol or associated type symbol
784+
/// whose protocol is in \p proto.
785+
///
786+
/// These rules form the requirement signatures of these protocols.
784787
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>>
785-
RewriteSystem::getMinimizedRules(ArrayRef<const ProtocolDecl *> protos) {
788+
RewriteSystem::getMinimizedProtocolRules(
789+
ArrayRef<const ProtocolDecl *> protos) const {
786790
assert(Minimized);
787791

788792
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>> rules;
@@ -806,6 +810,33 @@ RewriteSystem::getMinimizedRules(ArrayRef<const ProtocolDecl *> protos) {
806810
return rules;
807811
}
808812

813+
/// Collect all non-permanent, non-redundant rules whose left hand side
814+
/// begins with a generic parameter symbol.
815+
///
816+
/// These rules form the top-level generic signature for this rewrite system.
817+
std::vector<unsigned>
818+
RewriteSystem::getMinimizedGenericSignatureRules() const {
819+
assert(Minimized);
820+
821+
std::vector<unsigned> rules;
822+
for (unsigned ruleID : indices(Rules)) {
823+
const auto &rule = getRule(ruleID);
824+
825+
if (rule.isPermanent())
826+
continue;
827+
828+
if (rule.isRedundant())
829+
continue;
830+
831+
if (rule.getLHS()[0].getKind() != Symbol::Kind::GenericParam)
832+
continue;
833+
834+
rules.push_back(ruleID);
835+
}
836+
837+
return rules;
838+
}
839+
809840
/// Verify that each loop begins and ends at its basepoint.
810841
void RewriteSystem::verifyRewriteLoops() const {
811842
#ifndef NDEBUG

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct RewriteSystemBuilder {
6666

6767
RewriteSystemBuilder(RewriteContext &ctx, bool dump)
6868
: Context(ctx), Dump(dump) {}
69-
void addGenericSignature(CanGenericSignature sig);
69+
void addRequirements(ArrayRef<Requirement> requirements);
7070
void addProtocols(ArrayRef<const ProtocolDecl *> proto);
7171
void addProtocol(const ProtocolDecl *proto,
7272
bool initialComponent);
@@ -108,19 +108,18 @@ RewriteSystemBuilder::getConcreteSubstitutionSchema(CanType concreteType,
108108
}));
109109
}
110110

111-
void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
112-
// Collect all protocols transitively referenced from the generic signature's
113-
// requirements.
114-
for (auto req : sig.getRequirements()) {
111+
void RewriteSystemBuilder::addRequirements(ArrayRef<Requirement> requirements) {
112+
// Collect all protocols transitively referenced from these requirements.
113+
for (auto req : requirements) {
115114
if (req.getKind() == RequirementKind::Conformance) {
116115
addProtocol(req.getProtocolDecl(), /*initialComponent=*/false);
117116
}
118117
}
119118

120119
processProtocolDependencies();
121120

122-
// Add rewrite rules for all requirements in the top-level signature.
123-
for (const auto &req : sig.getRequirements())
121+
// Add rewrite rules for all top-level requirements.
122+
for (const auto &req : requirements)
124123
addRequirement(req, /*proto=*/nullptr);
125124
}
126125

@@ -350,7 +349,7 @@ void RequirementMachine::verify(const MutableTerm &term) const {
350349
// generic parameter.
351350
if (term.begin()->getKind() == Symbol::Kind::GenericParam) {
352351
auto *genericParam = term.begin()->getGenericParam();
353-
auto genericParams = Sig.getGenericParams();
352+
TypeArrayView<GenericTypeParamType> genericParams = getGenericParams();
354353
auto found = std::find(genericParams.begin(),
355354
genericParams.end(),
356355
genericParam);
@@ -428,8 +427,13 @@ void RequirementMachine::dump(llvm::raw_ostream &out) const {
428427
out << "Requirement machine for ";
429428
if (Sig)
430429
out << Sig;
431-
else {
432-
out << "[";
430+
else if (!Params.empty()) {
431+
out << "fresh signature ";
432+
for (auto paramTy : Params)
433+
out << " " << Type(paramTy);
434+
} else {
435+
assert(!Protos.empty());
436+
out << "protocols [";
433437
for (auto *proto : Protos) {
434438
out << " " << proto->getName();
435439
}
@@ -467,6 +471,8 @@ RequirementMachine::~RequirementMachine() {}
467471
/// performed on this requirement machine.
468472
void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
469473
Sig = sig;
474+
Params.append(sig.getGenericParams().begin(),
475+
sig.getGenericParams().end());
470476

471477
PrettyStackTraceGenericSignature debugStack("building rewrite system for", sig);
472478

@@ -485,7 +491,7 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
485491
// Collect the top-level requirements, and all transtively-referenced
486492
// protocol requirement signatures.
487493
RewriteSystemBuilder builder(Context, Dump);
488-
builder.addGenericSignature(sig);
494+
builder.addRequirements(sig.getRequirements());
489495

490496
// Add the initial set of rewrite rules to the rewrite system.
491497
System.initialize(/*recordLoops=*/false,

lib/AST/RequirementMachine/RequirementMachine.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class RequirementMachine final {
4747
friend class swift::rewriting::RewriteContext;
4848

4949
CanGenericSignature Sig;
50+
SmallVector<Type, 2> Params;
5051
ArrayRef<const ProtocolDecl *> Protos;
5152

5253
RewriteContext &Context;
@@ -87,8 +88,14 @@ class RequirementMachine final {
8788

8889
MutableTerm getLongestValidPrefix(const MutableTerm &term) const;
8990

90-
std::vector<Requirement> buildRequirementSignature(
91-
ArrayRef<unsigned> rules, const ProtocolDecl *proto) const;
91+
std::vector<Requirement> buildRequirementsFromRules(
92+
ArrayRef<unsigned> rules,
93+
TypeArrayView<GenericTypeParamType> genericParams) const;
94+
95+
TypeArrayView<GenericTypeParamType> getGenericParams() const {
96+
return TypeArrayView<GenericTypeParamType>(
97+
ArrayRef<Type>(Params));
98+
}
9299

93100
public:
94101
~RequirementMachine();
@@ -116,7 +123,9 @@ class RequirementMachine final {
116123
TypeDecl *lookupNestedType(Type depType, Identifier name) const;
117124

118125
llvm::DenseMap<const ProtocolDecl *, std::vector<Requirement>>
119-
computeMinimalRequirements();
126+
computeMinimalProtocolRequirements();
127+
128+
std::vector<Requirement> computeMinimalGenericSignatureRequirements();
120129

121130
void verify(const MutableTerm &term) const;
122131
void dump(llvm::raw_ostream &out) const;

lib/AST/RequirementMachine/RequirementMachineRequests.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,16 @@ void ConnectedComponent::buildRequirements(Type subjectType,
8888

8989
} // end namespace
9090

91-
/// Convert a list of non-permanent, non-redundant rewrite rules into a minimal
92-
/// protocol requirement signature for \p proto. The requirements are sorted in
93-
/// canonical order, and same-type requirements are canonicalized.
91+
/// Convert a list of non-permanent, non-redundant rewrite rules into a list of
92+
/// requirements sorted in canonical order. The \p genericParams are used to
93+
/// produce sugared types.
9494
std::vector<Requirement>
95-
RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
96-
const ProtocolDecl *proto) const {
95+
RequirementMachine::buildRequirementsFromRules(
96+
ArrayRef<unsigned> rules,
97+
TypeArrayView<GenericTypeParamType> genericParams) const {
9798
std::vector<Requirement> reqs;
9899
llvm::SmallDenseMap<TypeBase *, ConnectedComponent> sameTypeReqs;
99100

100-
auto genericParams = proto->getGenericSignature().getGenericParams();
101-
102101
// Convert a rewrite rule into a requirement.
103102
auto createRequirementFromRule = [&](const Rule &rule) {
104103
if (auto prop = rule.isPropertyRule()) {
@@ -196,23 +195,29 @@ RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
196195
/// Builds the requirement signatures for each protocol in this strongly
197196
/// connected component.
198197
llvm::DenseMap<const ProtocolDecl *, std::vector<Requirement>>
199-
RequirementMachine::computeMinimalRequirements() {
200-
assert(Protos.size() > 0);
198+
RequirementMachine::computeMinimalProtocolRequirements() {
199+
assert(Protos.size() > 0 &&
200+
"Not a protocol connected component rewrite system");
201+
assert(Params.empty() &&
202+
"Not a protocol connected component rewrite system");
203+
201204
System.minimizeRewriteSystem();
202205

203206
if (Dump) {
204207
llvm::dbgs() << "Minimized rewrite system:\n";
205208
dump(llvm::dbgs());
206209
}
207210

208-
auto rules = System.getMinimizedRules(Protos);
211+
auto rules = System.getMinimizedProtocolRules(Protos);
209212

210213
// Note that we build 'result' by iterating over 'Protos' rather than
211214
// 'rules'; this is intentional, so that even if a protocol has no
212215
// rules, we still end up creating an entry for it in 'result'.
213216
llvm::DenseMap<const ProtocolDecl *, std::vector<Requirement>> result;
214-
for (const auto *proto : Protos)
215-
result[proto] = buildRequirementSignature(rules[proto], proto);
217+
for (const auto *proto : Protos) {
218+
auto genericParams = proto->getGenericSignature().getGenericParams();
219+
result[proto] = buildRequirementsFromRules(rules[proto], genericParams);
220+
}
216221

217222
return result;
218223
}
@@ -229,7 +234,7 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
229234
// We build requirement signatures for all protocols in a strongly connected
230235
// component at the same time.
231236
auto *machine = ctx.getOrCreateRequirementMachine(proto);
232-
auto requirements = machine->computeMinimalRequirements();
237+
auto requirements = machine->computeMinimalProtocolRequirements();
233238

234239
bool debug = machine->getDebugOptions().contains(DebugFlags::Minimization);
235240

@@ -269,3 +274,22 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
269274
// Return the result for the specific protocol this request was kicked off on.
270275
return result;
271276
}
277+
278+
/// Builds the top-level generic signature requirements for this rewrite system.
279+
std::vector<Requirement>
280+
RequirementMachine::computeMinimalGenericSignatureRequirements() {
281+
assert(Protos.empty() &&
282+
"Not a top-level generic signature rewrite system");
283+
assert(!Params.empty() &&
284+
"Not a from-source top-level generic signature rewrite system");
285+
286+
System.minimizeRewriteSystem();
287+
288+
if (Dump) {
289+
llvm::dbgs() << "Minimized rewrite system:\n";
290+
dump(llvm::dbgs());
291+
}
292+
293+
auto rules = System.getMinimizedGenericSignatureRules();
294+
return buildRequirementsFromRules(rules, getGenericParams());
295+
}

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,9 @@ class RewriteSystem final {
340340
void minimizeRewriteSystem();
341341

342342
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>>
343-
getMinimizedRules(ArrayRef<const ProtocolDecl *> protos);
343+
getMinimizedProtocolRules(ArrayRef<const ProtocolDecl *> protos) const;
344+
345+
std::vector<unsigned> getMinimizedGenericSignatureRules() const;
344346

345347
void verifyRewriteLoops() const;
346348

0 commit comments

Comments
 (0)