1+ // ===--- ProtocolGraph.cpp - Collect information about protocols ----------===//
2+ //
3+ // This source file is part of the Swift.org open source project
4+ //
5+ // Copyright (c) 2021 Apple Inc. and the Swift project authors
6+ // Licensed under Apache License v2.0 with Runtime Library Exception
7+ //
8+ // See https://swift.org/LICENSE.txt for license information
9+ // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+ //
11+ // ===----------------------------------------------------------------------===//
12+
13+ #include " swift/AST/ProtocolGraph.h"
14+
15+ #include " swift/AST/Decl.h"
16+ #include " swift/AST/Requirement.h"
17+
18+ using namespace swift ;
19+ using namespace rewriting ;
20+
21+ void ProtocolGraph::visitRequirements (ArrayRef<Requirement> reqs) {
22+ for (auto req : reqs) {
23+ if (req.getKind () == RequirementKind::Conformance) {
24+ addProtocol (req.getProtocolDecl ());
25+ }
26+ }
27+ }
28+
29+ const ProtocolInfo &ProtocolGraph::getProtocolInfo (
30+ const ProtocolDecl *proto) const {
31+ auto found = Info.find (proto);
32+ assert (found != Info.end ());
33+ return found->second ;
34+ }
35+
36+ void ProtocolGraph::addProtocol (const ProtocolDecl *proto) {
37+ if (Info.count (proto) > 0 )
38+ return ;
39+
40+ Info[proto] = {proto->getInheritedProtocols (),
41+ proto->getAssociatedTypeMembers (),
42+ proto->getRequirementSignature ()};
43+ Protocols.push_back (proto);
44+ }
45+
46+ void ProtocolGraph::computeTransitiveClosure () {
47+ unsigned i = 0 ;
48+ while (i < Protocols.size ()) {
49+ auto *proto = Protocols[i++];
50+ visitRequirements (getProtocolInfo (proto).Requirements );
51+ }
52+ }
53+
54+ void ProtocolGraph::computeLinearOrder () {
55+ for (const auto *proto : Protocols) {
56+ (void ) computeProtocolDepth (proto);
57+ }
58+
59+ std::sort (
60+ Protocols.begin (), Protocols.end (),
61+ [&](const ProtocolDecl *lhs,
62+ const ProtocolDecl *rhs) -> bool {
63+ const auto &lhsInfo = getProtocolInfo (lhs);
64+ const auto &rhsInfo = getProtocolInfo (rhs);
65+
66+ // protocol Base {} // depth 1
67+ // protocol Derived : Base {} // depth 2
68+ //
69+ // Derived < Base in the linear order.
70+ if (lhsInfo.Depth != rhsInfo.Depth )
71+ return lhsInfo.Depth > rhsInfo.Depth ;
72+
73+ return TypeDecl::compare (lhs, rhs) < 0 ;
74+ });
75+
76+ for (unsigned i : indices (Protocols)) {
77+ Info[Protocols[i]].Index = i;
78+ }
79+
80+ if (Debug) {
81+ for (const auto *proto : Protocols) {
82+ const auto &info = getProtocolInfo (proto);
83+ llvm::dbgs () << " @ Protocol " << proto->getName ()
84+ << " Depth=" << info.Depth
85+ << " Index=" << info.Index << " \n " ;
86+ }
87+ }
88+ }
89+
90+ void ProtocolGraph::computeInheritedAssociatedTypes () {
91+ for (const auto *proto : Protocols) {
92+ auto &info = Info[proto];
93+
94+ llvm::SmallDenseSet<const AssociatedTypeDecl *, 4 > visited;
95+ for (const auto *inherited : info.Inherited ) {
96+ if (inherited == proto)
97+ continue ;
98+
99+ for (auto *inheritedType : getProtocolInfo (inherited).AssociatedTypes ) {
100+ if (!visited.insert (inheritedType).second )
101+ continue ;
102+
103+ // The 'if (inherited == proto)' above avoids a potential
104+ // iterator invalidation here.
105+ info.AssociatedTypes .push_back (inheritedType);
106+ }
107+ }
108+ }
109+ }
110+
111+ unsigned ProtocolGraph::computeProtocolDepth (const ProtocolDecl *proto) {
112+ auto &info = Info[proto];
113+
114+ if (info.Mark ) {
115+ // Already computed, or we have a cycle. Cycles are diagnosed
116+ // elsewhere in the type checker, so we don't have to do
117+ // anything here.
118+ return info.Depth ;
119+ }
120+
121+ info.Mark = true ;
122+ unsigned depth = 0 ;
123+
124+ for (auto *inherited : info.Inherited ) {
125+ unsigned inheritedDepth = computeProtocolDepth (inherited);
126+ depth = std::max (inheritedDepth, depth);
127+ }
128+
129+ depth++;
130+
131+ info.Depth = depth;
132+ return depth;
133+ }
0 commit comments