Skip to content

Commit ff157e6

Browse files
committed
RequirementMachine: Add notion of 'merged' associated type atoms
1 parent 8aaea2b commit ff157e6

File tree

2 files changed

+55
-19
lines changed

2 files changed

+55
-19
lines changed

include/swift/AST/RewriteSystem.h

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "swift/AST/Types.h"
2020
#include "llvm/ADT/PointerUnion.h"
2121
#include "llvm/ADT/SmallVector.h"
22+
#include "llvm/ADT/TinyPtrVector.h"
23+
#include <algorithm>
2224

2325
namespace llvm {
2426
class raw_ostream;
@@ -36,38 +38,46 @@ class Atom final {
3638
GenericTypeParamType *,
3739
LayoutConstraint>;
3840

39-
const ProtocolDecl *Proto;
41+
llvm::TinyPtrVector<const ProtocolDecl *> Protos;
4042
Storage Value;
4143

42-
explicit Atom(const ProtocolDecl *proto, Storage value)
43-
: Proto(proto), Value(value) {
44+
explicit Atom(llvm::TinyPtrVector<const ProtocolDecl *> protos,
45+
Storage value)
46+
: Protos(protos), Value(value) {
4447
// Triggers assertion if the atom is not valid.
4548
(void) getKind();
4649
}
4750

4851
public:
4952
static Atom forName(Identifier name) {
50-
return Atom(nullptr, name);
53+
return Atom({}, name);
5154
}
5255

5356
static Atom forProtocol(const ProtocolDecl *proto) {
54-
return Atom(proto, Storage());
57+
return Atom({proto}, Storage());
5558
}
5659

5760
static Atom forAssociatedType(const ProtocolDecl *proto,
5861
Identifier name) {
5962
assert(proto != nullptr);
60-
return Atom(proto, name);
63+
return Atom({proto}, name);
64+
}
65+
66+
static Atom forAssociatedType(
67+
llvm::TinyPtrVector<const ProtocolDecl *>protos,
68+
Identifier name) {
69+
assert(!protos.empty());
70+
return Atom(protos, name);
6171
}
6272

6373
static Atom forGenericParam(GenericTypeParamType *param) {
6474
assert(param->isCanonical());
65-
return Atom(nullptr, param);
75+
return Atom({}, param);
6676
}
6777

6878
static Atom forLayout(LayoutConstraint layout) {
6979
assert(layout->isKnownLayout());
70-
return Atom(nullptr, layout);
80+
return Atom({}, layout);
7181
}
7282

7383
enum class Kind : uint8_t {
@@ -80,23 +90,23 @@ class Atom final {
8090

8191
Kind getKind() const {
8292
if (!Value) {
83-
assert(Proto != nullptr);
93+
assert(Protos.size() == 1);
8494
return Kind::Protocol;
8595
}
8696

8797
if (Value.is<Identifier>()) {
88-
if (Proto != nullptr)
98+
if (!Protos.empty())
8999
return Kind::AssociatedType;
90100
return Kind::Name;
91101
}
92102

93103
if (Value.is<GenericTypeParamType *>()) {
94-
assert(Proto == nullptr);
104+
assert(Protos.empty());
95105
return Kind::GenericParam;
96106
}
97107

98108
if (Value.is<LayoutConstraint>()) {
99-
assert(Proto == nullptr);
109+
assert(Protos.empty());
100110
return Kind::Layout;
101111
}
102112

@@ -110,9 +120,16 @@ class Atom final {
110120
}
111121

112122
const ProtocolDecl *getProtocol() const {
123+
assert(getKind() == Kind::Protocol);
124+
assert(Protos.size() == 1);
125+
return Protos.front();
126+
}
127+
128+
llvm::TinyPtrVector<const ProtocolDecl *> getProtocols() const {
113129
assert(getKind() == Kind::Protocol ||
114130
getKind() == Kind::AssociatedType);
115-
return Proto;
131+
assert(!Protos.empty());
132+
return Protos;
116133
}
117134

118135
GenericTypeParamType *getGenericParam() const {
@@ -130,7 +147,9 @@ class Atom final {
130147
void dump(llvm::raw_ostream &out) const;
131148

132149
friend bool operator==(Atom lhs, Atom rhs) {
133-
return (lhs.Proto == rhs.Proto &&
150+
return (lhs.Protos.size() == rhs.Protos.size() &&
151+
std::equal(lhs.Protos.begin(), lhs.Protos.end(),
152+
rhs.Protos.begin()) &&
134153
lhs.Value == rhs.Value);
135154
}
136155

lib/AST/RewriteSystem.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,17 @@ int Atom::compare(Atom other, ProtocolOrder protocolOrder) const {
3434
return protocolOrder(getProtocol(), other.getProtocol());
3535

3636
case Kind::AssociatedType: {
37-
int result = protocolOrder(getProtocol(), other.getProtocol());
38-
if (result)
39-
return result;
37+
auto protos = getProtocols();
38+
auto otherProtos = other.getProtocols();
39+
40+
if (protos.size() != otherProtos.size())
41+
return otherProtos.size() > protos.size() ? -1 : 1;
42+
43+
for (unsigned i : indices(protos)) {
44+
int result = protocolOrder(protos[i], otherProtos[i]);
45+
if (result)
46+
return result;
47+
}
4048

4149
return getName().compare(other.getName());
4250
}
@@ -73,8 +81,17 @@ void Atom::dump(llvm::raw_ostream &out) const {
7381
return;
7482

7583
case Kind::AssociatedType: {
76-
out << "[" << getProtocol()->getName()
77-
<< ":" << getName() << "]";
84+
out << "[";
85+
bool first = true;
86+
for (const auto *proto : getProtocols()) {
87+
if (first) {
88+
first = false;
89+
} else {
90+
out << "&";
91+
}
92+
out << proto->getName();
93+
}
94+
out << ":" << getName() << "]";
7895
return;
7996
}
8097

0 commit comments

Comments
 (0)