Skip to content

Commit 793474f

Browse files
committed
RequirementMachine: Associated type inheritance
1 parent 69b7a64 commit 793474f

File tree

3 files changed

+192
-30
lines changed

3 files changed

+192
-30
lines changed

include/swift/AST/RewriteSystem.h

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,35 @@ namespace swift {
2828
namespace rewriting {
2929

3030
class Atom final {
31-
using Storage = llvm::PointerUnion<Identifier,
32-
const ProtocolDecl *,
33-
const AssociatedTypeDecl *,
34-
GenericTypeParamType *>;
31+
using Storage = llvm::PointerUnion<Identifier, GenericTypeParamType *>;
32+
33+
const ProtocolDecl *Proto;
3534
Storage Value;
3635

37-
explicit Atom(Storage value) : Value(value) {}
36+
explicit Atom(const ProtocolDecl *proto, Storage value)
37+
: Proto(proto), Value(value) {
38+
// Triggers assertion if the atom is not valid.
39+
(void) getKind();
40+
}
3841

3942
public:
4043
static Atom forName(Identifier name) {
41-
return Atom(name);
44+
return Atom(nullptr, name);
4245
}
4346

4447
static Atom forProtocol(const ProtocolDecl *proto) {
45-
assert(proto != nullptr);
46-
return Atom(proto);
48+
return Atom(proto, Storage());
4749
}
4850

49-
static Atom forAssociatedType(const AssociatedTypeDecl *type) {
50-
assert(type != nullptr);
51-
return Atom(type);
51+
static Atom forAssociatedType(const ProtocolDecl *proto,
52+
Identifier name) {
53+
assert(proto != nullptr);
54+
return Atom(proto, name);
5255
}
5356

5457
static Atom forGenericParam(GenericTypeParamType *param) {
5558
assert(param->isCanonical());
56-
return Atom(param);
59+
return Atom(nullptr, param);
5760
}
5861

5962
enum class Kind : uint8_t {
@@ -64,30 +67,39 @@ class Atom final {
6467
};
6568

6669
Kind getKind() const {
67-
if (Value.is<Identifier>())
68-
return Kind::Name;
69-
if (Value.is<const ProtocolDecl *>())
70+
if (!Value) {
71+
assert(Proto != nullptr);
7072
return Kind::Protocol;
71-
if (Value.is<const AssociatedTypeDecl *>())
72-
return Kind::AssociatedType;
73-
if (Value.is<GenericTypeParamType *>())
73+
}
74+
75+
if (Value.is<Identifier>()) {
76+
if (Proto != nullptr)
77+
return Kind::AssociatedType;
78+
return Kind::Name;
79+
}
80+
81+
if (Value.is<GenericTypeParamType *>()) {
82+
assert(Proto == nullptr);
7483
return Kind::GenericParam;
84+
}
85+
7586
llvm_unreachable("Bad term rewriting atom");
7687
}
7788

7889
Identifier getName() const {
90+
assert(getKind() == Kind::Name ||
91+
getKind() == Kind::AssociatedType);
7992
return Value.get<Identifier>();
8093
}
8194

8295
const ProtocolDecl *getProtocol() const {
83-
return Value.get<const ProtocolDecl *>();
84-
}
85-
86-
const AssociatedTypeDecl *getAssociatedType() const {
87-
return Value.get<const AssociatedTypeDecl *>();
96+
assert(getKind() == Kind::Protocol ||
97+
getKind() == Kind::AssociatedType);
98+
return Proto;
8899
}
89100

90101
GenericTypeParamType *getGenericParam() const {
102+
assert(getKind() == Kind::GenericParam);
91103
return Value.get<GenericTypeParamType *>();
92104
}
93105

@@ -96,7 +108,8 @@ class Atom final {
96108
void dump(llvm::raw_ostream &out) const;
97109

98110
friend bool operator==(Atom lhs, Atom rhs) {
99-
return lhs.Value == rhs.Value;
111+
return (lhs.Proto == rhs.Proto &&
112+
lhs.Value == rhs.Value);
100113
}
101114
};
102115

@@ -183,6 +196,7 @@ class Rule final {
183196

184197
class RewriteSystem final {
185198
std::vector<Rule> Rules;
199+
bool Debug = false;
186200

187201
public:
188202
bool addRule(Term lhs, Term rhs);

lib/AST/RequirementMachine.cpp

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ struct ProtocolInfo {
3030
ArrayRef<ProtocolDecl *> Inherited;
3131
llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
3232
ArrayRef<Requirement> Requirements;
33+
34+
// Used by computeDepth() to detect circularity.
35+
unsigned Mark : 1;
36+
37+
// Longest chain of protocol refinements, including this one.
38+
// Greater than zero on valid code, might be zero if there's
39+
// a cycle.
40+
unsigned Depth : 31;
41+
42+
// Index of the protocol in the linear order.
43+
unsigned Index : 32;
44+
45+
ProtocolInfo() {
46+
Mark = 0;
47+
Depth = 0;
48+
Index = 0;
49+
}
50+
51+
ProtocolInfo(ArrayRef<ProtocolDecl *> inherited,
52+
llvm::TinyPtrVector<AssociatedTypeDecl *> &&types,
53+
ArrayRef<Requirement> reqs)
54+
: Inherited(inherited),
55+
AssociatedTypes(types),
56+
Requirements(reqs) {
57+
Mark = 0;
58+
Depth = 0;
59+
Index = 0;
60+
}
3361
};
3462

3563
struct ProtocolGraph {
@@ -61,6 +89,79 @@ struct ProtocolGraph {
6189
visitRequirements(Info[proto].Requirements);
6290
}
6391
}
92+
93+
void computeLinearOrder() {
94+
for (const auto *proto : Protocols) {
95+
(void) computeProtocolDepth(proto);
96+
}
97+
98+
std::sort(
99+
Protocols.begin(), Protocols.end(),
100+
[&](const ProtocolDecl *lhs,
101+
const ProtocolDecl *rhs) -> int {
102+
const auto &lhsInfo = Info[lhs];
103+
const auto &rhsInfo = Info[rhs];
104+
105+
// protocol Base {} // depth 1
106+
// protocol Derived : Base {} // depth 2
107+
//
108+
// Derived < Base in the linear order.
109+
if (lhsInfo.Depth != rhsInfo.Depth)
110+
return lhsInfo.Depth - rhsInfo.Depth;
111+
112+
return TypeDecl::compare(lhs, rhs);
113+
});
114+
115+
for (unsigned i : indices(Protocols)) {
116+
Info[Protocols[i]].Index = i;
117+
}
118+
}
119+
120+
void computeInheritedAssociatedTypes() {
121+
for (const auto *proto : Protocols) {
122+
auto &info = Info[proto];
123+
124+
llvm::SmallDenseSet<const AssociatedTypeDecl *, 4> visited;
125+
for (const auto *inherited : info.Inherited) {
126+
if (inherited == proto)
127+
continue;
128+
129+
for (auto *inheritedType : Info[inherited].AssociatedTypes) {
130+
if (!visited.insert(inheritedType).second)
131+
continue;
132+
133+
// The 'if (inherited == proto)' above avoids a potential
134+
// iterator invalidation here.
135+
info.AssociatedTypes.push_back(inheritedType);
136+
}
137+
}
138+
}
139+
}
140+
141+
private:
142+
unsigned computeProtocolDepth(const ProtocolDecl *proto) {
143+
auto &info = Info[proto];
144+
145+
if (info.Mark) {
146+
// Already computed, or we have a cycle. Cycles are diagnosed
147+
// elsewhere in the type checker, so we don't have to do
148+
// anything here.
149+
return info.Depth;
150+
}
151+
152+
info.Mark = true;
153+
unsigned depth = 0;
154+
155+
for (auto *inherited : info.Inherited) {
156+
unsigned inheritedDepth = computeProtocolDepth(inherited);
157+
depth = std::max(inheritedDepth, depth);
158+
}
159+
160+
depth++;
161+
162+
info.Depth = depth;
163+
return depth;
164+
}
64165
};
65166

66167
class RewriteSystemBuilder {
@@ -73,6 +174,9 @@ class RewriteSystemBuilder {
73174
void addGenericSignature(CanGenericSignature sig);
74175
void addAssociatedType(const AssociatedTypeDecl *type,
75176
const ProtocolDecl *proto);
177+
void addInheritedAssociatedType(const AssociatedTypeDecl *type,
178+
const ProtocolDecl *inherited,
179+
const ProtocolDecl *proto);
76180
void addRequirement(const Requirement &req,
77181
const ProtocolDecl *proto);
78182

@@ -85,6 +189,8 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
85189
ProtocolGraph graph;
86190
graph.visitRequirements(sig->getRequirements());
87191
graph.computeTransitiveClosure();
192+
graph.computeLinearOrder();
193+
graph.computeInheritedAssociatedTypes();
88194

89195
for (auto *proto : graph.Protocols) {
90196
if (Context.LangOpts.DebugRequirementMachine) {
@@ -96,6 +202,12 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
96202
for (auto *type : info.AssociatedTypes)
97203
addAssociatedType(type, proto);
98204

205+
for (auto *inherited : info.Inherited) {
206+
for (auto *inheritedType : graph.Info[inherited].AssociatedTypes) {
207+
addInheritedAssociatedType(inheritedType, inherited, proto);
208+
}
209+
}
210+
99211
for (auto req : info.Requirements)
100212
addRequirement(req.getCanonical(), proto);
101213

@@ -115,7 +227,21 @@ void RewriteSystemBuilder::addAssociatedType(const AssociatedTypeDecl *type,
115227
lhs.add(Atom::forName(type->getName()));
116228

117229
Term rhs;
118-
rhs.add(Atom::forAssociatedType(type));
230+
rhs.add(Atom::forAssociatedType(proto, type->getName()));
231+
232+
Rules.emplace_back(lhs, rhs);
233+
}
234+
235+
void RewriteSystemBuilder::addInheritedAssociatedType(
236+
const AssociatedTypeDecl *type,
237+
const ProtocolDecl *inherited,
238+
const ProtocolDecl *proto) {
239+
Term lhs;
240+
lhs.add(Atom::forProtocol(proto));
241+
lhs.add(Atom::forAssociatedType(inherited, type->getName()));
242+
243+
Term rhs;
244+
rhs.add(Atom::forAssociatedType(proto, type->getName()));
119245

120246
Rules.emplace_back(lhs, rhs);
121247
}

lib/AST/RewriteSystem.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ int Atom::compare(Atom other) const {
3232
case Kind::Protocol:
3333
return TypeDecl::compare(getProtocol(), other.getProtocol());
3434

35-
case Kind::AssociatedType:
36-
return TypeDecl::compare(getAssociatedType(), other.getAssociatedType());
35+
case Kind::AssociatedType: {
36+
int result = TypeDecl::compare(getProtocol(), other.getProtocol());
37+
if (result)
38+
return result;
39+
40+
return getName().compare(other.getName());
41+
}
3742

3843
case Kind::GenericParam: {
3944
auto *param = getGenericParam();
@@ -63,9 +68,8 @@ void Atom::dump(llvm::raw_ostream &out) const {
6368
return;
6469

6570
case Kind::AssociatedType: {
66-
auto *type = getAssociatedType();
67-
out << "[" << type->getProtocol()->getName()
68-
<< ":" << type->getName() << "]";
71+
out << "[" << getProtocol()->getName()
72+
<< ":" << getName() << "]";
6973
return;
7074
}
7175

@@ -206,13 +210,31 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
206210
bool RewriteSystem::simplify(Term &term) const {
207211
bool changed = false;
208212

213+
if (Debug) {
214+
llvm::dbgs() << "= Term ";
215+
term.dump(llvm::dbgs());
216+
llvm::dbgs() << "\n";
217+
}
218+
209219
while (true) {
210220
bool tryAgain = false;
211221
for (const auto &rule : Rules) {
212222
if (rule.isDeleted())
213223
continue;
214224

225+
if (Debug) {
226+
llvm::dbgs() << "== Rule ";
227+
rule.dump(llvm::dbgs());
228+
llvm::dbgs() << "\n";
229+
}
230+
215231
if (rule.apply(term)) {
232+
if (Debug) {
233+
llvm::dbgs() << "=== Result ";
234+
term.dump(llvm::dbgs());
235+
llvm::dbgs() << "\n";
236+
}
237+
216238
changed = true;
217239
tryAgain = true;
218240
}

0 commit comments

Comments
 (0)