1
+ // ===--- ProtocolGraph.cpp - Collect information about protocols ----------===//
2
+ //
3
+ // This source file is part of the Swift.org open source project
4
+ //
5
+ // Copyright (c) 2021 Apple Inc. and the Swift project authors
6
+ // Licensed under Apache License v2.0 with Runtime Library Exception
7
+ //
8
+ // See https://swift.org/LICENSE.txt for license information
9
+ // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
+ //
11
+ // ===----------------------------------------------------------------------===//
12
+
13
+ #include " swift/AST/ProtocolGraph.h"
14
+
15
+ #include " swift/AST/Decl.h"
16
+ #include " swift/AST/Requirement.h"
17
+
18
+ using namespace swift ;
19
+ using namespace rewriting ;
20
+
21
+ void ProtocolGraph::visitRequirements (ArrayRef<Requirement> reqs) {
22
+ for (auto req : reqs) {
23
+ if (req.getKind () == RequirementKind::Conformance) {
24
+ addProtocol (req.getProtocolDecl ());
25
+ }
26
+ }
27
+ }
28
+
29
+ const ProtocolInfo &ProtocolGraph::getProtocolInfo (
30
+ const ProtocolDecl *proto) const {
31
+ auto found = Info.find (proto);
32
+ assert (found != Info.end ());
33
+ return found->second ;
34
+ }
35
+
36
+ void ProtocolGraph::addProtocol (const ProtocolDecl *proto) {
37
+ if (Info.count (proto) > 0 )
38
+ return ;
39
+
40
+ Info[proto] = {proto->getInheritedProtocols (),
41
+ proto->getAssociatedTypeMembers (),
42
+ proto->getRequirementSignature ()};
43
+ Protocols.push_back (proto);
44
+ }
45
+
46
+ void ProtocolGraph::computeTransitiveClosure () {
47
+ unsigned i = 0 ;
48
+ while (i < Protocols.size ()) {
49
+ auto *proto = Protocols[i++];
50
+ visitRequirements (getProtocolInfo (proto).Requirements );
51
+ }
52
+ }
53
+
54
+ void ProtocolGraph::computeLinearOrder () {
55
+ for (const auto *proto : Protocols) {
56
+ (void ) computeProtocolDepth (proto);
57
+ }
58
+
59
+ std::sort (
60
+ Protocols.begin (), Protocols.end (),
61
+ [&](const ProtocolDecl *lhs,
62
+ const ProtocolDecl *rhs) -> bool {
63
+ const auto &lhsInfo = getProtocolInfo (lhs);
64
+ const auto &rhsInfo = getProtocolInfo (rhs);
65
+
66
+ // protocol Base {} // depth 1
67
+ // protocol Derived : Base {} // depth 2
68
+ //
69
+ // Derived < Base in the linear order.
70
+ if (lhsInfo.Depth != rhsInfo.Depth )
71
+ return lhsInfo.Depth > rhsInfo.Depth ;
72
+
73
+ return TypeDecl::compare (lhs, rhs) < 0 ;
74
+ });
75
+
76
+ for (unsigned i : indices (Protocols)) {
77
+ Info[Protocols[i]].Index = i;
78
+ }
79
+
80
+ if (Debug) {
81
+ for (const auto *proto : Protocols) {
82
+ const auto &info = getProtocolInfo (proto);
83
+ llvm::dbgs () << " @ Protocol " << proto->getName ()
84
+ << " Depth=" << info.Depth
85
+ << " Index=" << info.Index << " \n " ;
86
+ }
87
+ }
88
+ }
89
+
90
+ void ProtocolGraph::computeInheritedAssociatedTypes () {
91
+ for (const auto *proto : Protocols) {
92
+ auto &info = Info[proto];
93
+
94
+ llvm::SmallDenseSet<const AssociatedTypeDecl *, 4 > visited;
95
+ for (const auto *inherited : info.Inherited ) {
96
+ if (inherited == proto)
97
+ continue ;
98
+
99
+ for (auto *inheritedType : getProtocolInfo (inherited).AssociatedTypes ) {
100
+ if (!visited.insert (inheritedType).second )
101
+ continue ;
102
+
103
+ // The 'if (inherited == proto)' above avoids a potential
104
+ // iterator invalidation here.
105
+ info.AssociatedTypes .push_back (inheritedType);
106
+ }
107
+ }
108
+ }
109
+ }
110
+
111
+ unsigned ProtocolGraph::computeProtocolDepth (const ProtocolDecl *proto) {
112
+ auto &info = Info[proto];
113
+
114
+ if (info.Mark ) {
115
+ // Already computed, or we have a cycle. Cycles are diagnosed
116
+ // elsewhere in the type checker, so we don't have to do
117
+ // anything here.
118
+ return info.Depth ;
119
+ }
120
+
121
+ info.Mark = true ;
122
+ unsigned depth = 0 ;
123
+
124
+ for (auto *inherited : info.Inherited ) {
125
+ unsigned inheritedDepth = computeProtocolDepth (inherited);
126
+ depth = std::max (inheritedDepth, depth);
127
+ }
128
+
129
+ depth++;
130
+
131
+ info.Depth = depth;
132
+ return depth;
133
+ }
0 commit comments