Skip to content

Commit 62c9347

Browse files
committed
RequirementMachine: Overhaul generating conformances algorithm
1 parent 2db676d commit 62c9347

File tree

2 files changed

+170
-27
lines changed

2 files changed

+170
-27
lines changed

lib/AST/RequirementMachine/GeneratingConformances.cpp

Lines changed: 160 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ void HomotopyGenerator::findProtocolConformanceRules(
5656
SmallVectorImpl<std::pair<MutableTerm, unsigned>> &inContext,
5757
const RewriteSystem &system) const {
5858

59-
MutableTerm term = Basepoint;
59+
auto redundancyCandidates = Path.findRulesAppearingOnceInEmptyContext();
60+
if (redundancyCandidates.empty())
61+
return;
6062

6163
for (const auto &step : Path) {
6264
switch (step.Kind) {
@@ -65,14 +67,46 @@ void HomotopyGenerator::findProtocolConformanceRules(
6567
if (!rule.isProtocolConformanceRule())
6668
break;
6769

68-
if (!step.isInContext()) {
69-
assert(std::find(notInContext.begin(),
70-
notInContext.end(),
71-
step.RuleID) == notInContext.end() &&
72-
"A conformance rule appears more than once without context?");
70+
if (!step.isInContext() &&
71+
step.Inverse &&
72+
std::find(redundancyCandidates.begin(),
73+
redundancyCandidates.end(),
74+
step.RuleID) != redundancyCandidates.end()) {
7375
notInContext.push_back(step.RuleID);
74-
} else if (step.EndOffset == 0) {
75-
assert(step.StartOffset > 0);
76+
}
77+
78+
break;
79+
}
80+
81+
case RewriteStep::AdjustConcreteType:
82+
break;
83+
}
84+
}
85+
86+
if (notInContext.empty())
87+
return;
88+
89+
if (notInContext.size() > 1) {
90+
llvm::errs() << "Multiple conformance rules appear once without context:\n";
91+
for (unsigned ruleID : notInContext)
92+
llvm::errs() << system.getRule(ruleID) << "\n";
93+
dump(llvm::errs(), system);
94+
llvm::errs() << "\n";
95+
abort();
96+
}
97+
98+
MutableTerm term = Basepoint;
99+
100+
for (const auto &step : Path) {
101+
switch (step.Kind) {
102+
case RewriteStep::ApplyRewriteRule: {
103+
const auto &rule = system.getRule(step.RuleID);
104+
if (!rule.isProtocolConformanceRule())
105+
break;
106+
107+
if (step.StartOffset > 0 &&
108+
step.EndOffset == 0 &&
109+
rule.getLHS().back() == system.getRule(notInContext[0]).getLHS().back()) {
76110
MutableTerm prefix(term.begin(), term.begin() + step.StartOffset);
77111
inContext.emplace_back(prefix, step.RuleID);
78112
}
@@ -86,8 +120,13 @@ void HomotopyGenerator::findProtocolConformanceRules(
86120
step.apply(term, system);
87121
}
88122

89-
if (notInContext.size() == 1 && inContext.empty()) {
90-
llvm::errs() << "A conformance rule not based on another conformance rule?\n";
123+
if (inContext.empty()) {
124+
notInContext.clear();
125+
return;
126+
}
127+
128+
if (inContext.size() > 1) {
129+
llvm::errs() << "Multiple candidate conformance rules in context?\n";
91130
dump(llvm::errs(), system);
92131
llvm::errs() << "\n";
93132
abort();
@@ -252,6 +291,8 @@ void RewriteSystem::computeCandidateConformancePaths(
252291
unsigned ruleID = pair.second;
253292
llvm::dbgs() << " (#" << ruleID << ") " << getRule(ruleID) << "\n";
254293
}
294+
295+
llvm::dbgs() << "\n";
255296
}
256297

257298
// Suppose a 3-cell contains a conformance rule (T.[P] => T) in an empty
@@ -354,43 +395,111 @@ bool RewriteSystem::isValidConformancePath(
354395
return true;
355396
}
356397

398+
void RewriteSystem::dumpGeneratingConformanceEquation(
399+
llvm::raw_ostream &out,
400+
unsigned baseRuleID,
401+
const std::vector<SmallVector<unsigned, 2>> &paths) const {
402+
out << getRule(baseRuleID).getLHS() << " := ";
403+
404+
bool first = true;
405+
for (const auto &path : paths) {
406+
if (!first)
407+
out << "";
408+
else
409+
first = false;
410+
for (unsigned ruleID : path)
411+
out << "(" << getRule(ruleID).getLHS() << ")";
412+
}
413+
}
414+
415+
void RewriteSystem::verifyGeneratingConformanceEquations(
416+
const llvm::MapVector<unsigned,
417+
std::vector<SmallVector<unsigned, 2>>>
418+
&conformancePaths) const {
419+
#ifndef NDEBUG
420+
for (const auto &pair : conformancePaths) {
421+
const auto &rule = getRule(pair.first);
422+
auto *proto = rule.getLHS().back().getProtocol();
423+
424+
MutableTerm baseTerm(rule.getLHS());
425+
(void) simplify(baseTerm);
426+
427+
for (const auto &path : pair.second) {
428+
const auto &otherRule = getRule(path.back());
429+
auto *otherProto = otherRule.getLHS().back().getProtocol();
430+
431+
if (proto != otherProto) {
432+
llvm::errs() << "Invalid equation: ";
433+
dumpGeneratingConformanceEquation(llvm::errs(),
434+
pair.first, pair.second);
435+
llvm::errs() << "\n";
436+
llvm::errs() << "Mismatched conformance:\n";
437+
llvm::errs() << "Base rule: " << rule << "\n";
438+
llvm::errs() << "Final rule: " << otherRule << "\n\n";
439+
dump(llvm::errs());
440+
abort();
441+
}
442+
443+
MutableTerm otherTerm;
444+
for (unsigned otherRuleID : path) {
445+
otherTerm.append(getRule(otherRuleID).getLHS());
446+
}
447+
448+
(void) simplify(otherTerm);
449+
450+
if (baseTerm != otherTerm) {
451+
llvm::errs() << "Invalid equation: ";
452+
llvm::errs() << "\n";
453+
dumpGeneratingConformanceEquation(llvm::errs(),
454+
pair.first, pair.second);
455+
llvm::errs() << "Invalid conformance path:\n";
456+
llvm::errs() << "Expected: " << baseTerm << "\n";
457+
llvm::errs() << "Got: " << otherTerm << "\n\n";
458+
dump(llvm::errs());
459+
abort();
460+
}
461+
}
462+
}
463+
#endif
464+
}
465+
357466
/// Computes a minimal set of generating conformances, assuming that homotopy
358467
/// reduction has already eliminated all redundant rewrite rules that are not
359468
/// conformance rules.
360469
void RewriteSystem::computeGeneratingConformances(
361470
llvm::DenseSet<unsigned> &redundantConformances) {
362471
llvm::MapVector<unsigned, std::vector<SmallVector<unsigned, 2>>> conformancePaths;
363472

473+
// Prepare the initial set of equations: every non-redundant conformance rule
474+
// can be expressed as itself.
364475
for (unsigned ruleID : indices(Rules)) {
365476
const auto &rule = getRule(ruleID);
366-
if (rule.isProtocolConformanceRule()) {
367-
SmallVector<unsigned, 2> path;
368-
path.push_back(ruleID);
369-
conformancePaths[ruleID].push_back(path);
370-
}
477+
if (rule.isRedundant())
478+
continue;
479+
480+
if (!rule.isProtocolConformanceRule())
481+
continue;
482+
483+
SmallVector<unsigned, 2> path;
484+
path.push_back(ruleID);
485+
conformancePaths[ruleID].push_back(path);
371486
}
372487

373488
computeCandidateConformancePaths(conformancePaths);
374489

375490
if (Debug.contains(DebugFlags::GeneratingConformances)) {
376491
llvm::dbgs() << "Initial set of equations:\n";
377492
for (const auto &pair : conformancePaths) {
378-
llvm::dbgs() << "- " << getRule(pair.first).getLHS() << " := ";
379-
380-
bool first = true;
381-
for (const auto &path : pair.second) {
382-
if (!first)
383-
llvm::dbgs() << "";
384-
else
385-
first = false;
386-
for (unsigned ruleID : path)
387-
llvm::dbgs() << "(" << getRule(ruleID).getLHS() << ")";
388-
}
389-
493+
llvm::dbgs() << "- ";
494+
dumpGeneratingConformanceEquation(llvm::dbgs(),
495+
pair.first, pair.second);
390496
llvm::dbgs() << "\n";
391497
}
392498
}
393499

500+
verifyGeneratingConformanceEquations(conformancePaths);
501+
502+
// Find a minimal set of generating conformances.
394503
for (const auto &pair : conformancePaths) {
395504
for (const auto &path : pair.second) {
396505
llvm::SmallDenseSet<unsigned, 4> visited;
@@ -404,6 +513,30 @@ void RewriteSystem::computeGeneratingConformances(
404513
}
405514
}
406515

516+
// Check invariants.
517+
#ifndef NDEBUG
518+
for (const auto &pair : conformancePaths) {
519+
if (redundantConformances.count(pair.first) > 0)
520+
continue;
521+
522+
const auto &rule = getRule(pair.first);
523+
524+
if (rule.isRedundant()) {
525+
llvm::errs() << "Generating conformance is redundant: ";
526+
llvm::errs() << rule << "\n\n";
527+
dump(llvm::errs());
528+
abort();
529+
}
530+
531+
if (rule.containsUnresolvedSymbols()) {
532+
llvm::errs() << "Generating conformance contains unresolved symbols: ";
533+
llvm::errs() << rule << "\n\n";
534+
dump(llvm::errs());
535+
abort();
536+
}
537+
}
538+
#endif
539+
407540
if (Debug.contains(DebugFlags::GeneratingConformances)) {
408541
llvm::dbgs() << "Generating conformances:\n";
409542

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,16 @@ class RewriteSystem final {
495495
std::vector<SmallVector<unsigned, 2>>>
496496
&conformancePaths) const;
497497

498+
void dumpGeneratingConformanceEquation(
499+
llvm::raw_ostream &out,
500+
unsigned baseRuleID,
501+
const std::vector<SmallVector<unsigned, 2>> &paths) const;
502+
503+
void verifyGeneratingConformanceEquations(
504+
const llvm::MapVector<unsigned,
505+
std::vector<SmallVector<unsigned, 2>>>
506+
&conformancePaths) const;
507+
498508
void computeGeneratingConformances(
499509
llvm::DenseSet<unsigned> &redundantConformances);
500510

0 commit comments

Comments
 (0)