Skip to content

Commit 863f63a

Browse files
committed
RequirementMachine: Fixes to get stdlib to build
1 parent efae2bf commit 863f63a

File tree

3 files changed

+67
-22
lines changed

3 files changed

+67
-22
lines changed

include/swift/AST/RewriteSystem.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class Atom final {
114114
return (lhs.Proto == rhs.Proto &&
115115
lhs.Value == rhs.Value);
116116
}
117+
118+
friend bool operator!=(Atom lhs, Atom rhs) {
119+
return !(lhs == rhs);
120+
}
117121
};
118122

119123
class Term final {
@@ -196,16 +200,26 @@ class Rule final {
196200
return LHS.size();
197201
}
198202

203+
int compare(const Rule &other,
204+
ProtocolOrder protocolOrder) const {
205+
return LHS.compare(other.LHS, protocolOrder);
206+
}
207+
199208
void dump(llvm::raw_ostream &out) const;
200209
};
201210

202211
class RewriteSystem final {
203212
std::vector<Rule> Rules;
204213
ProtocolOrder Order;
205-
bool Debug = false;
214+
215+
unsigned DebugSimplify : 1;
216+
unsigned DebugAdd : 1;
206217

207218
public:
208-
explicit RewriteSystem(ProtocolOrder order) : Order(order) {}
219+
explicit RewriteSystem(ProtocolOrder order) : Order(order) {
220+
DebugSimplify = false;
221+
DebugAdd = false;
222+
}
209223

210224
RewriteSystem(const RewriteSystem &) = delete;
211225
RewriteSystem(RewriteSystem &&) = delete;

lib/AST/RequirementMachine.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ struct ProtocolInfo {
6363
struct ProtocolGraph {
6464
llvm::DenseMap<const ProtocolDecl *, ProtocolInfo> Info;
6565
std::vector<const ProtocolDecl *> Protocols;
66+
bool Debug = false;
6667

6768
void visitRequirements(ArrayRef<Requirement> reqs) {
6869
for (auto req : reqs) {
@@ -105,7 +106,7 @@ struct ProtocolGraph {
105106
std::sort(
106107
Protocols.begin(), Protocols.end(),
107108
[&](const ProtocolDecl *lhs,
108-
const ProtocolDecl *rhs) -> int {
109+
const ProtocolDecl *rhs) -> bool {
109110
const auto &lhsInfo = getProtocolInfo(lhs);
110111
const auto &rhsInfo = getProtocolInfo(rhs);
111112

@@ -114,14 +115,23 @@ struct ProtocolGraph {
114115
//
115116
// Derived < Base in the linear order.
116117
if (lhsInfo.Depth != rhsInfo.Depth)
117-
return lhsInfo.Depth - rhsInfo.Depth;
118+
return lhsInfo.Depth > rhsInfo.Depth;
118119

119-
return TypeDecl::compare(lhs, rhs);
120+
return TypeDecl::compare(lhs, rhs) < 0;
120121
});
121122

122123
for (unsigned i : indices(Protocols)) {
123124
Info[Protocols[i]].Index = i;
124125
}
126+
127+
if (Debug) {
128+
for (const auto *proto : Protocols) {
129+
const auto &info = getProtocolInfo(proto);
130+
llvm::dbgs() << "@ Protocol " << proto->getName()
131+
<< " Depth=" << info.Depth
132+
<< " Index=" << info.Index << "\n";
133+
}
134+
}
125135
}
126136

127137
void computeInheritedAssociatedTypes() {
@@ -197,6 +207,10 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
197207
Protocols.computeInheritedAssociatedTypes();
198208

199209
for (auto *proto : Protocols.Protocols) {
210+
if (Context.LangOpts.DebugRequirementMachine) {
211+
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
212+
}
213+
200214
const auto &info = Protocols.getProtocolInfo(proto);
201215

202216
for (auto *type : info.AssociatedTypes)
@@ -317,12 +331,10 @@ struct RequirementMachine::Implementation {
317331
Implementation()
318332
: Order([&](const ProtocolDecl *lhs,
319333
const ProtocolDecl *rhs) -> int {
320-
auto infoLHS = Protocols.Info.find(lhs);
321-
assert(infoLHS != Protocols.Info.end());
322-
auto infoRHS = Protocols.Info.find(rhs);
323-
assert(infoRHS != Protocols.Info.end());
334+
const auto &infoLHS = Protocols.getProtocolInfo(lhs);
335+
const auto &infoRHS = Protocols.getProtocolInfo(rhs);
324336

325-
return infoRHS->second.Index - infoLHS->second.Index;
337+
return infoLHS.Index - infoRHS.Index;
326338
}),
327339
System(Order) {}
328340
};
@@ -351,6 +363,11 @@ void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
351363

352364
Impl->Protocols = builder.Protocols;
353365

366+
std::sort(builder.Rules.begin(), builder.Rules.end(),
367+
[&](std::pair<Term, Term> lhs,
368+
std::pair<Term, Term> rhs) -> int {
369+
return lhs.first.compare(rhs.first, Impl->Order) < 0;
370+
});
354371
for (const auto &rule : builder.Rules)
355372
Impl->System.addRule(rule.first, rule.second);
356373

@@ -364,14 +381,14 @@ void RequirementMachine::addGenericSignature(CanGenericSignature sig) {
364381
case RewriteSystem::CompletionResult::MaxIterations:
365382
llvm::errs() << "Generic signature " << sig
366383
<< " exceeds maximum completion step count\n";
367-
break;
368-
// abort();
384+
Impl->System.dump(llvm::errs());
385+
abort();
369386

370387
case RewriteSystem::CompletionResult::MaxDepth:
371388
llvm::errs() << "Generic signature " << sig
372389
<< " exceeds maximum completion depth\n";
373-
break;
374-
// abort();
390+
Impl->System.dump(llvm::errs());
391+
abort();
375392
}
376393

377394
markComplete();

lib/AST/RewriteSystem.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "swift/AST/RewriteSystem.h"
1414
#include "llvm/Support/raw_ostream.h"
1515
#include <algorithm>
16+
#include <deque>
1617
#include <vector>
1718

1819
using namespace swift;
@@ -42,7 +43,7 @@ int Atom::compare(Atom other, ProtocolOrder protocolOrder) const {
4243

4344
case Kind::GenericParam: {
4445
auto *param = getGenericParam();
45-
auto *otherParam = getGenericParam();
46+
auto *otherParam = other.getGenericParam();
4647

4748
if (param->getDepth() != otherParam->getDepth())
4849
return param->getDepth() < otherParam->getDepth() ? -1 : 1;
@@ -90,8 +91,12 @@ int Term::compare(const Term &other, ProtocolOrder protocolOrder) const {
9091
auto rhs = other[i];
9192

9293
int result = lhs.compare(rhs, protocolOrder);
93-
if (result != 0)
94+
if (result != 0) {
95+
assert(lhs != rhs);
9496
return result;
97+
}
98+
99+
assert(lhs == rhs);
95100
}
96101

97102
return 0;
@@ -201,6 +206,15 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
201206
if (result < 0)
202207
std::swap(lhs, rhs);
203208

209+
assert(lhs.compare(rhs, Order) > 0);
210+
211+
if (DebugAdd) {
212+
llvm::dbgs() << "# Adding rule ";
213+
lhs.dump(llvm::dbgs());
214+
llvm::dbgs() << " => ";
215+
rhs.dump(llvm::dbgs());
216+
llvm::dbgs() << "\n";
217+
}
204218
Rules.emplace_back(lhs, rhs);
205219

206220
return true;
@@ -209,7 +223,7 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
209223
bool RewriteSystem::simplify(Term &term) const {
210224
bool changed = false;
211225

212-
if (Debug) {
226+
if (DebugSimplify) {
213227
llvm::dbgs() << "= Term ";
214228
term.dump(llvm::dbgs());
215229
llvm::dbgs() << "\n";
@@ -221,14 +235,14 @@ bool RewriteSystem::simplify(Term &term) const {
221235
if (rule.isDeleted())
222236
continue;
223237

224-
if (Debug) {
238+
if (DebugSimplify) {
225239
llvm::dbgs() << "== Rule ";
226240
rule.dump(llvm::dbgs());
227241
llvm::dbgs() << "\n";
228242
}
229243

230244
if (rule.apply(term)) {
231-
if (Debug) {
245+
if (DebugSimplify) {
232246
llvm::dbgs() << "=== Result ";
233247
term.dump(llvm::dbgs());
234248
llvm::dbgs() << "\n";
@@ -250,7 +264,7 @@ RewriteSystem::CompletionResult
250264
RewriteSystem::computeConfluentCompletion(
251265
unsigned maxIterations,
252266
unsigned maxDepth) {
253-
SmallVector<std::pair<unsigned, unsigned>, 16> worklist;
267+
std::deque<std::pair<unsigned, unsigned>> worklist;
254268

255269
for (unsigned i : indices(Rules)) {
256270
for (unsigned j : indices(Rules)) {
@@ -262,8 +276,8 @@ RewriteSystem::computeConfluentCompletion(
262276
}
263277

264278
while (!worklist.empty()) {
265-
auto pair = worklist.back();
266-
worklist.pop_back();
279+
auto pair = worklist.front();
280+
worklist.pop_front();
267281

268282
Term first;
269283

0 commit comments

Comments
 (0)