Skip to content

Commit 85d17a6

Browse files
committed
RequirementMachine: Split off ProtocolGraph into its own file
1 parent ff157e6 commit 85d17a6

File tree

6 files changed

+225
-156
lines changed

6 files changed

+225
-156
lines changed

include/swift/AST/LayoutConstraintKind.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616

17+
#include "llvm/Support/DataTypes.h"
18+
1719
#ifndef SWIFT_LAYOUT_CONSTRAINTKIND_H
1820
#define SWIFT_LAYOUT_CONSTRAINTKIND_H
1921

include/swift/AST/ProtocolGraph.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===--- ProtocolGraph.h - Collects information about protocols -*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2021 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_PROTOCOLGRAPH_H
14+
#define SWIFT_PROTOCOLGRAPH_H
15+
16+
#include "swift/AST/Requirement.h"
17+
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/DenseMap.h"
19+
#include "llvm/ADT/TinyPtrVector.h"
20+
21+
namespace swift {
22+
23+
class ProtocolDecl;
24+
class AssociatedTypeDecl;
25+
26+
namespace rewriting {
27+
28+
struct ProtocolInfo {
29+
ArrayRef<ProtocolDecl *> Inherited;
30+
llvm::TinyPtrVector<AssociatedTypeDecl *> AssociatedTypes;
31+
ArrayRef<Requirement> Requirements;
32+
33+
// Used by computeDepth() to detect circularity.
34+
unsigned Mark : 1;
35+
36+
// Longest chain of protocol refinements, including this one.
37+
// Greater than zero on valid code, might be zero if there's
38+
// a cycle.
39+
unsigned Depth : 31;
40+
41+
// Index of the protocol in the linear order.
42+
unsigned Index : 32;
43+
44+
ProtocolInfo() {
45+
Mark = 0;
46+
Depth = 0;
47+
Index = 0;
48+
}
49+
50+
ProtocolInfo(ArrayRef<ProtocolDecl *> inherited,
51+
llvm::TinyPtrVector<AssociatedTypeDecl *> &&types,
52+
ArrayRef<Requirement> reqs)
53+
: Inherited(inherited),
54+
AssociatedTypes(types),
55+
Requirements(reqs) {
56+
Mark = 0;
57+
Depth = 0;
58+
Index = 0;
59+
}
60+
};
61+
62+
struct ProtocolGraph {
63+
llvm::DenseMap<const ProtocolDecl *, ProtocolInfo> Info;
64+
std::vector<const ProtocolDecl *> Protocols;
65+
bool Debug = false;
66+
67+
void visitRequirements(ArrayRef<Requirement> reqs);
68+
69+
const ProtocolInfo &getProtocolInfo(
70+
const ProtocolDecl *proto) const;
71+
72+
void addProtocol(const ProtocolDecl *proto);
73+
74+
void computeTransitiveClosure();
75+
76+
void computeLinearOrder();
77+
78+
void computeInheritedAssociatedTypes();
79+
80+
private:
81+
unsigned computeProtocolDepth(const ProtocolDecl *proto);
82+
};
83+
84+
} // end namespace rewriting
85+
86+
} // end namespace swift
87+
#endif

lib/AST/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ add_swift_host_library(swiftAST STATIC
7171
PlatformKind.cpp
7272
PrettyStackTrace.cpp
7373
ProtocolConformance.cpp
74+
ProtocolGraph.cpp
7475
RawComment.cpp
7576
RequirementEnvironment.cpp
7677
RequirementMachine.cpp

lib/AST/ProtocolGraph.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
//===--- ProtocolGraph.cpp - Collect information about protocols ----------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2021 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "swift/AST/ProtocolGraph.h"
14+
15+
#include "swift/AST/Decl.h"
16+
#include "swift/AST/Requirement.h"
17+
18+
using namespace swift;
19+
using namespace rewriting;
20+
21+
void ProtocolGraph::visitRequirements(ArrayRef<Requirement> reqs) {
22+
for (auto req : reqs) {
23+
if (req.getKind() == RequirementKind::Conformance) {
24+
addProtocol(req.getProtocolDecl());
25+
}
26+
}
27+
}
28+
29+
const ProtocolInfo &ProtocolGraph::getProtocolInfo(
30+
const ProtocolDecl *proto) const {
31+
auto found = Info.find(proto);
32+
assert(found != Info.end());
33+
return found->second;
34+
}
35+
36+
void ProtocolGraph::addProtocol(const ProtocolDecl *proto) {
37+
if (Info.count(proto) > 0)
38+
return;
39+
40+
Info[proto] = {proto->getInheritedProtocols(),
41+
proto->getAssociatedTypeMembers(),
42+
proto->getRequirementSignature()};
43+
Protocols.push_back(proto);
44+
}
45+
46+
void ProtocolGraph::computeTransitiveClosure() {
47+
unsigned i = 0;
48+
while (i < Protocols.size()) {
49+
auto *proto = Protocols[i++];
50+
visitRequirements(getProtocolInfo(proto).Requirements);
51+
}
52+
}
53+
54+
void ProtocolGraph::computeLinearOrder() {
55+
for (const auto *proto : Protocols) {
56+
(void) computeProtocolDepth(proto);
57+
}
58+
59+
std::sort(
60+
Protocols.begin(), Protocols.end(),
61+
[&](const ProtocolDecl *lhs,
62+
const ProtocolDecl *rhs) -> bool {
63+
const auto &lhsInfo = getProtocolInfo(lhs);
64+
const auto &rhsInfo = getProtocolInfo(rhs);
65+
66+
// protocol Base {} // depth 1
67+
// protocol Derived : Base {} // depth 2
68+
//
69+
// Derived < Base in the linear order.
70+
if (lhsInfo.Depth != rhsInfo.Depth)
71+
return lhsInfo.Depth > rhsInfo.Depth;
72+
73+
return TypeDecl::compare(lhs, rhs) < 0;
74+
});
75+
76+
for (unsigned i : indices(Protocols)) {
77+
Info[Protocols[i]].Index = i;
78+
}
79+
80+
if (Debug) {
81+
for (const auto *proto : Protocols) {
82+
const auto &info = getProtocolInfo(proto);
83+
llvm::dbgs() << "@ Protocol " << proto->getName()
84+
<< " Depth=" << info.Depth
85+
<< " Index=" << info.Index << "\n";
86+
}
87+
}
88+
}
89+
90+
void ProtocolGraph::computeInheritedAssociatedTypes() {
91+
for (const auto *proto : Protocols) {
92+
auto &info = Info[proto];
93+
94+
llvm::SmallDenseSet<const AssociatedTypeDecl *, 4> visited;
95+
for (const auto *inherited : info.Inherited) {
96+
if (inherited == proto)
97+
continue;
98+
99+
for (auto *inheritedType : getProtocolInfo(inherited).AssociatedTypes) {
100+
if (!visited.insert(inheritedType).second)
101+
continue;
102+
103+
// The 'if (inherited == proto)' above avoids a potential
104+
// iterator invalidation here.
105+
info.AssociatedTypes.push_back(inheritedType);
106+
}
107+
}
108+
}
109+
}
110+
111+
unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {
112+
auto &info = Info[proto];
113+
114+
if (info.Mark) {
115+
// Already computed, or we have a cycle. Cycles are diagnosed
116+
// elsewhere in the type checker, so we don't have to do
117+
// anything here.
118+
return info.Depth;
119+
}
120+
121+
info.Mark = true;
122+
unsigned depth = 0;
123+
124+
for (auto *inherited : info.Inherited) {
125+
unsigned inheritedDepth = computeProtocolDepth(inherited);
126+
depth = std::max(inheritedDepth, depth);
127+
}
128+
129+
depth++;
130+
131+
info.Depth = depth;
132+
return depth;
133+
}

0 commit comments

Comments
 (0)