Skip to content

Commit 20df303

Browse files
committed
RequirementMachine: Unique Atoms
1 parent 332c65f commit 20df303

File tree

5 files changed

+370
-149
lines changed

5 files changed

+370
-149
lines changed

include/swift/AST/RequirementMachine.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,6 @@ class GenericSignature;
2323
class ProtocolDecl;
2424
class Requirement;
2525

26-
namespace rewriting {
27-
28-
class Term;
29-
30-
Term getTermForType(CanType paramType, const ProtocolDecl *proto);
31-
32-
} // end namespace rewriting
33-
3426
/// Wraps a rewrite system with higher-level operations in terms of
3527
/// generic signatures and interface types.
3628
class RequirementMachine final {

include/swift/AST/RewriteSystem.h

Lines changed: 182 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
#include "swift/AST/LayoutConstraint.h"
1919
#include "swift/AST/ProtocolGraph.h"
2020
#include "swift/AST/Types.h"
21+
#include "llvm/ADT/FoldingSet.h"
2122
#include "llvm/ADT/PointerUnion.h"
2223
#include "llvm/ADT/SmallVector.h"
2324
#include "llvm/ADT/TinyPtrVector.h"
25+
#include "llvm/Support/Allocator.h"
26+
#include "llvm/Support/TrailingObjects.h"
2427
#include <algorithm>
2528

2629
namespace llvm {
@@ -31,7 +34,9 @@ namespace swift {
3134

3235
namespace rewriting {
3336

34-
/// The most primitive term in the rewrite system.
37+
class RewriteContext;
38+
39+
/// The smallest element in the rewrite system.
3540
///
3641
/// enum Atom {
3742
/// case name(Identifier)
@@ -43,61 +48,7 @@ namespace rewriting {
4348
///
4449
/// Out-of-line methods are documented in RewriteSystem.cpp.
4550
class Atom final {
46-
using Storage = llvm::PointerUnion<Identifier,
47-
GenericTypeParamType *,
48-
LayoutConstraint>;
49-
50-
llvm::TinyPtrVector<const ProtocolDecl *> Protos;
51-
Storage Value;
52-
53-
explicit Atom(llvm::TinyPtrVector<const ProtocolDecl *> protos,
54-
Storage value)
55-
: Protos(protos), Value(value) {
56-
// Triggers assertion if the atom is not valid.
57-
(void) getKind();
58-
}
59-
6051
public:
61-
/// Creates a new name atom.
62-
static Atom forName(Identifier name) {
63-
return Atom({}, name);
64-
}
65-
66-
/// Creates a new protocol atom.
67-
static Atom forProtocol(const ProtocolDecl *proto) {
68-
return Atom({proto}, Storage());
69-
}
70-
71-
/// Creates a new associated type atom for a single protocol.
72-
static Atom forAssociatedType(const ProtocolDecl *proto,
73-
Identifier name) {
74-
assert(proto != nullptr);
75-
return Atom({proto}, name);
76-
}
77-
78-
/// Creates a merged associated type atom to represent a nested
79-
/// type that conforms to multiple protocols, all of which have
80-
/// an associated type with the same name.
81-
static Atom forAssociatedType(
82-
llvm::TinyPtrVector<const ProtocolDecl *>protos,
83-
Identifier name) {
84-
assert(!protos.empty());
85-
return Atom(protos, name);
86-
}
87-
88-
/// Creates a generic parameter atom, representing a generic
89-
/// parameter in the top-level generic signature from which the
90-
/// rewrite system is built.
91-
static Atom forGenericParam(GenericTypeParamType *param) {
92-
assert(param->isCanonical());
93-
return Atom({}, param);
94-
}
95-
96-
/// Creates a layout atom, representing a layout constraint.
97-
static Atom forLayout(LayoutConstraint layout) {
98-
return Atom({}, layout);
99-
}
100-
10152
enum class Kind : uint8_t {
10253
/// An associated type [P:T] or [P&Q&...:T]. The parent term
10354
/// must be known to conform to P (or P, Q, ...).
@@ -124,76 +75,190 @@ class Atom final {
12475
Layout
12576
};
12677

127-
Kind getKind() const {
128-
if (!Value) {
129-
assert(Protos.size() == 1);
130-
return Kind::Protocol;
78+
private:
79+
friend class RewriteContext;
80+
81+
/// Atoms are uniqued and immutable, stored as a single pointer;
82+
/// the Storage type is the allocated backing storage.
83+
struct Storage final
84+
: public llvm::FoldingSetNode,
85+
public llvm::TrailingObjects<Storage, const ProtocolDecl *> {
86+
friend class Atom;
87+
88+
unsigned Kind : 16;
89+
unsigned NumProtocols : 16;
90+
91+
union {
92+
Identifier Name;
93+
LayoutConstraint Layout;
94+
const ProtocolDecl *Proto;
95+
GenericTypeParamType *GenericParam;
96+
};
97+
98+
explicit Storage(Identifier name) {
99+
Kind = unsigned(Atom::Kind::Name);
100+
NumProtocols = 0;
101+
Name = name;
102+
}
103+
104+
explicit Storage(LayoutConstraint layout) {
105+
Kind = unsigned(Atom::Kind::Layout);
106+
NumProtocols = 0;
107+
Layout = layout;
108+
}
109+
110+
explicit Storage(const ProtocolDecl *proto) {
111+
Kind = unsigned(Atom::Kind::Protocol);
112+
NumProtocols = 0;
113+
Proto = proto;
131114
}
132115

133-
if (Value.is<Identifier>()) {
134-
if (!Protos.empty())
135-
return Kind::AssociatedType;
136-
return Kind::Name;
116+
explicit Storage(GenericTypeParamType *param) {
117+
Kind = unsigned(Atom::Kind::GenericParam);
118+
NumProtocols = 0;
119+
GenericParam = param;
137120
}
138121

139-
if (Value.is<GenericTypeParamType *>()) {
140-
assert(Protos.empty());
141-
return Kind::GenericParam;
122+
Storage(ArrayRef<const ProtocolDecl *> protos, Identifier name) {
123+
assert(!protos.empty());
124+
125+
Kind = unsigned(Atom::Kind::AssociatedType);
126+
NumProtocols = protos.size();
127+
Name = name;
128+
129+
for (unsigned i : indices(protos))
130+
getProtocols()[i] = protos[i];
131+
}
132+
133+
size_t numTrailingObjects(OverloadToken<const ProtocolDecl *>) const {
134+
return NumProtocols;
135+
}
136+
137+
MutableArrayRef<const ProtocolDecl *> getProtocols() {
138+
return {getTrailingObjects<const ProtocolDecl *>(), NumProtocols};
142139
}
143140

144-
if (Value.is<LayoutConstraint>()) {
145-
assert(Protos.empty());
146-
return Kind::Layout;
141+
ArrayRef<const ProtocolDecl *> getProtocols() const {
142+
return {getTrailingObjects<const ProtocolDecl *>(), NumProtocols};
147143
}
148144

149-
llvm_unreachable("Bad term rewriting atom");
145+
void Profile(llvm::FoldingSetNodeID &id) {
146+
id.AddInteger(Kind);
147+
148+
switch (Atom::Kind(Kind)) {
149+
case Atom::Kind::Name:
150+
id.AddPointer(Name.get());
151+
return;
152+
153+
case Atom::Kind::Layout:
154+
id.AddPointer(Layout.getPointer());
155+
return;
156+
157+
case Atom::Kind::Protocol:
158+
id.AddPointer(Proto);
159+
return;
160+
161+
case Atom::Kind::GenericParam:
162+
id.AddPointer(GenericParam);
163+
return;
164+
165+
case Atom::Kind::AssociatedType: {
166+
auto protos = getProtocols();
167+
id.AddInteger(protos.size());
168+
169+
for (const auto *proto : protos)
170+
id.AddPointer(proto);
171+
172+
id.AddPointer(Name.get());
173+
return;
174+
}
175+
}
176+
177+
llvm_unreachable("Bad atom kind");
178+
}
179+
};
180+
181+
private:
182+
const Storage *Ptr;
183+
184+
Atom(const Storage *ptr) : Ptr(ptr) {}
185+
186+
public:
187+
Kind getKind() const {
188+
return Kind(Ptr->Kind);
150189
}
151190

152191
/// Get the identifier associated with an unbound name atom or an
153192
/// associated type atom.
154193
Identifier getName() const {
155194
assert(getKind() == Kind::Name ||
156195
getKind() == Kind::AssociatedType);
157-
return Value.get<Identifier>();
196+
return Ptr->Name;
158197
}
159198

160199
/// Get the single protocol declaration associate with a protocol atom.
161200
const ProtocolDecl *getProtocol() const {
162201
assert(getKind() == Kind::Protocol);
163-
assert(Protos.size() == 1);
164-
return Protos.front();
202+
return Ptr->Proto;
165203
}
166204

167205
/// Get the list of protocols associated with a protocol or associated
168206
/// type atom.
169-
llvm::TinyPtrVector<const ProtocolDecl *> getProtocols() const {
170-
assert(getKind() == Kind::Protocol ||
171-
getKind() == Kind::AssociatedType);
172-
assert(!Protos.empty());
173-
return Protos;
207+
ArrayRef<const ProtocolDecl *> getProtocols() const {
208+
assert(getKind() == Kind::AssociatedType);
209+
auto protos = Ptr->getProtocols();
210+
assert(!protos.empty());
211+
return protos;
174212
}
175213

176214
/// Get the generic parameter associated with a generic parameter atom.
177215
GenericTypeParamType *getGenericParam() const {
178216
assert(getKind() == Kind::GenericParam);
179-
return Value.get<GenericTypeParamType *>();
217+
return Ptr->GenericParam;
180218
}
181219

182220
/// Get the layout constraint associated with a layout constraint atom.
183221
LayoutConstraint getLayoutConstraint() const {
184222
assert(getKind() == Kind::Layout);
185-
return Value.get<LayoutConstraint>();
223+
return Ptr->Layout;
186224
}
187225

226+
/// Creates a new name atom.
227+
static Atom forName(Identifier name,
228+
RewriteContext &ctx);
229+
230+
/// Creates a new protocol atom.
231+
static Atom forProtocol(const ProtocolDecl *proto,
232+
RewriteContext &ctx);
233+
234+
/// Creates a new associated type atom for a single protocol.
235+
static Atom forAssociatedType(const ProtocolDecl *proto,
236+
Identifier name,
237+
RewriteContext &ctx);
238+
239+
/// Creates a merged associated type atom to represent a nested
240+
/// type that conforms to multiple protocols, all of which have
241+
/// an associated type with the same name.
242+
static Atom forAssociatedType(ArrayRef<const ProtocolDecl *> protos,
243+
Identifier name,
244+
RewriteContext &ctx);
245+
246+
/// Creates a generic parameter atom, representing a generic
247+
/// parameter in the top-level generic signature from which the
248+
/// rewrite system is built.
249+
static Atom forGenericParam(GenericTypeParamType *param,
250+
RewriteContext &ctx);
251+
252+
/// Creates a layout atom, representing a layout constraint.
253+
static Atom forLayout(LayoutConstraint layout,
254+
RewriteContext &ctx);
255+
188256
int compare(Atom other, const ProtocolGraph &protos) const;
189257

190258
void dump(llvm::raw_ostream &out) const;
191259

192260
friend bool operator==(Atom lhs, Atom rhs) {
193-
return (lhs.Protos.size() == rhs.Protos.size() &&
194-
std::equal(lhs.Protos.begin(), lhs.Protos.end(),
195-
rhs.Protos.begin()) &&
196-
lhs.Value == rhs.Value);
261+
return lhs.Ptr == rhs.Ptr;
197262
}
198263

199264
friend bool operator!=(Atom lhs, Atom rhs) {
@@ -267,6 +332,32 @@ class Term final {
267332
void dump(llvm::raw_ostream &out) const;
268333
};
269334

335+
/// A global object that can be shared by multiple rewrite systems.
336+
///
337+
/// It stores uniqued atoms and terms.
338+
///
339+
/// Out-of-line methods are documented in RewriteSystem.cpp.
340+
class RewriteContext final {
341+
friend class Atom;
342+
343+
/// Allocator for uniquing atoms and terms.
344+
llvm::BumpPtrAllocator Allocator;
345+
346+
/// Folding set for uniquing atoms.
347+
llvm::FoldingSet<Atom::Storage> Atoms;
348+
349+
RewriteContext(const RewriteContext &) = delete;
350+
RewriteContext(RewriteContext &&) = delete;
351+
RewriteContext &operator=(const RewriteContext &) = delete;
352+
RewriteContext &operator=(RewriteContext &&) = delete;
353+
354+
public:
355+
RewriteContext() {}
356+
357+
Term getTermForType(CanType paramType,
358+
const ProtocolDecl *proto);
359+
};
360+
270361
/// A rewrite rule that replaces occurrences of LHS with RHS.
271362
///
272363
/// LHS must be greater than RHS in the linear order over terms.
@@ -330,6 +421,8 @@ class Rule final {
330421
///
331422
/// Out-of-line methods are documented in RewriteSystem.cpp.
332423
class RewriteSystem final {
424+
RewriteContext &Context;
425+
333426
/// The rules added so far, including rules from our client, as well
334427
/// as rules introduced by the completion procedure.
335428
std::vector<Rule> Rules;
@@ -360,7 +453,7 @@ class RewriteSystem final {
360453
unsigned DebugMerge : 1;
361454

362455
public:
363-
explicit RewriteSystem() {
456+
explicit RewriteSystem(RewriteContext &ctx) : Context(ctx) {
364457
DebugSimplify = false;
365458
DebugAdd = false;
366459
DebugMerge = false;
@@ -371,6 +464,10 @@ class RewriteSystem final {
371464
RewriteSystem &operator=(const RewriteSystem &) = delete;
372465
RewriteSystem &operator=(RewriteSystem &&) = delete;
373466

467+
/// Return the rewrite context used for allocating memory.
468+
RewriteContext &getRewriteContext() const { return Context; }
469+
470+
/// Return the object recording information about known protocols.
374471
const ProtocolGraph &getProtocols() const { return Protos; }
375472

376473
void initialize(std::vector<std::pair<Term, Term>> &&rules,

lib/AST/ProtocolGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ void ProtocolGraph::computeInheritedProtocols() {
158158
}
159159
}
160160

161-
/// Recursively compute the 'depth' of e protocol, which is inductively defined
161+
/// Recursively compute the 'depth' of a protocol, which is inductively defined
162162
/// as one greater than the depth of all inherited protocols, with a protocol
163163
/// that does not inherit any other protocol having a depth of one.
164164
unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {

0 commit comments

Comments
 (0)