Skip to content

Commit af443be

Browse files
committed
RequirementMachine: Implement RewriteContext::getRequirementMachine(ProtocolDecl *)
1 parent aaf84ac commit af443be

File tree

2 files changed

+119
-17
lines changed

2 files changed

+119
-17
lines changed

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 102 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ RequirementMachine *RewriteContext::getRequirementMachine(
133133
auto *newMachine = new rewriting::RequirementMachine(*this);
134134
machine = newMachine;
135135

136-
// This might re-entrantly invalidate 'machine', which is a reference
137-
// into Protos.
136+
// This might re-entrantly invalidate 'machine'.
138137
auto status = newMachine->initWithGenericSignature(sig);
139138
newMachine->checkCompletionResult(status.first);
140139

@@ -170,7 +169,10 @@ void RewriteContext::getProtocolComponentRec(
170169
stack.push_back(proto);
171170

172171
// Look at each successor.
173-
for (auto *depProto : proto->getProtocolDependencies()) {
172+
auto found = Dependencies.find(proto);
173+
assert(found != Dependencies.end());
174+
175+
for (auto *depProto : found->second) {
174176
auto found = Protos.find(depProto);
175177
if (found == Protos.end()) {
176178
// Successor has not yet been visited. Recurse.
@@ -223,40 +225,121 @@ void RewriteContext::getProtocolComponentRec(
223225
}
224226
}
225227

226-
/// Lazily construct a requirement machine for the given protocol's strongly
227-
/// connected component (SCC) in the protocol dependency graph.
228+
/// Get the strongly connected component (SCC) of the protocol dependency
229+
/// graph containing the given protocol.
228230
///
229-
/// This can only be called once, to prevent multiple requirement machines
230-
/// for being built with the same component.
231-
ArrayRef<const ProtocolDecl *> RewriteContext::getProtocolComponent(
232-
const ProtocolDecl *proto) {
231+
/// You must not hold on to this reference across calls to any other
232+
/// Requirement Machine operations, since they might insert new entries
233+
/// into the underlying DenseMap, invalidating the reference.
234+
RewriteContext::ProtocolComponent &
235+
RewriteContext::getProtocolComponentImpl(const ProtocolDecl *proto) {
236+
{
237+
// We pre-load protocol dependencies into the Dependencies map
238+
// because getProtocolDependencies() can trigger recursive calls into
239+
// the requirement machine in highly-invalid code, which violates
240+
// invariants in getProtocolComponentRec().
241+
SmallVector<const ProtocolDecl *, 3> worklist;
242+
worklist.push_back(proto);
243+
244+
while (!worklist.empty()) {
245+
const auto *otherProto = worklist.back();
246+
worklist.pop_back();
247+
248+
auto found = Dependencies.find(otherProto);
249+
if (found != Dependencies.end())
250+
continue;
251+
252+
auto protoDeps = otherProto->getProtocolDependencies();
253+
Dependencies.insert(std::make_pair(otherProto, protoDeps));
254+
for (auto *nextProto : protoDeps)
255+
worklist.push_back(nextProto);
256+
}
257+
}
258+
233259
auto found = Protos.find(proto);
234260
if (found == Protos.end()) {
261+
if (ProtectProtocolComponentRec) {
262+
llvm::errs() << "Too much recursion is bad\n";
263+
abort();
264+
}
265+
266+
ProtectProtocolComponentRec = true;
267+
235268
SmallVector<const ProtocolDecl *, 3> stack;
236269
getProtocolComponentRec(proto, stack);
237270
assert(stack.empty());
238271

239272
found = Protos.find(proto);
240273
assert(found != Protos.end());
274+
275+
ProtectProtocolComponentRec = false;
241276
}
242277

243278
assert(Components.count(found->second.ComponentID) != 0);
244279
auto &component = Components[found->second.ComponentID];
245280

246-
if (component.InProgress) {
247-
llvm::errs() << "Re-entrant construction of requirement "
248-
<< "machine for:";
281+
assert(std::find(component.Protos.begin(), component.Protos.end(), proto)
282+
!= component.Protos.end() && "Protocol is in the wrong SCC");
283+
return component;
284+
}
285+
286+
/// Get the list of protocols in the strongly connected component (SCC)
287+
/// of the protocol dependency graph containing the given protocol.
288+
///
289+
/// This can only be called once, to prevent multiple requirement machines
290+
/// for being built with the same component.
291+
ArrayRef<const ProtocolDecl *> RewriteContext::getProtocolComponent(
292+
const ProtocolDecl *proto) {
293+
auto &component = getProtocolComponentImpl(proto);
294+
295+
if (component.ComputingRequirementSignatures) {
296+
llvm::errs() << "Re-entrant minimization of requirement signatures for: ";
249297
for (auto *proto : component.Protos)
250298
llvm::errs() << " " << proto->getName();
251299
llvm::errs() << "\n";
252300
abort();
253301
}
254302

255-
component.InProgress = true;
303+
component.ComputingRequirementSignatures = true;
256304

257305
return component.Protos;
258306
}
259307

308+
/// Get the list of protocols in the strongly connected component (SCC)
309+
/// of the protocol dependency graph containing the given protocol.
310+
///
311+
/// This can only be called once, to prevent multiple requirement machines
312+
/// for being built with the same component.
313+
RequirementMachine *RewriteContext::getRequirementMachine(
314+
const ProtocolDecl *proto) {
315+
auto &component = getProtocolComponentImpl(proto);
316+
317+
if (component.Machine) {
318+
if (!component.Machine->isComplete()) {
319+
llvm::errs() << "Re-entrant construction of requirement machine for: ";
320+
for (auto *proto : component.Protos)
321+
llvm::errs() << " " << proto->getName();
322+
llvm::errs() << "\n";
323+
abort();
324+
}
325+
326+
return component.Machine;
327+
}
328+
329+
// Store this requirement machine before adding the protocols, to catch
330+
// re-entrant construction via initWithProtocolSignatureRequirements()
331+
// below.
332+
auto *newMachine = new rewriting::RequirementMachine(*this);
333+
component.Machine = newMachine;
334+
335+
// This might re-entrantly invalidate 'component.Machine'.
336+
auto status = newMachine->initWithProtocolSignatureRequirements(
337+
component.Protos);
338+
newMachine->checkCompletionResult(status.first);
339+
340+
return newMachine;
341+
}
342+
260343
bool RewriteContext::isRecursivelyConstructingRequirementMachine(
261344
const ProtocolDecl *proto) {
262345
if (proto->isRequirementSignatureComputed())
@@ -270,12 +353,17 @@ bool RewriteContext::isRecursivelyConstructingRequirementMachine(
270353
if (component == Components.end())
271354
return false;
272355

273-
return component->second.InProgress;
356+
return component->second.ComputingRequirementSignatures;
274357
}
275358

276359
/// We print stats in the destructor, which should get executed at the end of
277360
/// a compilation job.
278361
RewriteContext::~RewriteContext() {
362+
for (const auto &pair : Components)
363+
delete pair.second.Machine;
364+
365+
Components.clear();
366+
279367
for (const auto &pair : Machines)
280368
delete pair.second;
281369

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,28 @@ class RewriteContext final {
8686
/// The members of this connected component.
8787
ArrayRef<const ProtocolDecl *> Protos;
8888

89-
/// Each connected component has a lazily-created requirement machine.
90-
bool InProgress = false;
89+
/// Whether we are currently computing the requirement signatures of
90+
/// the protocols in this component.
91+
bool ComputingRequirementSignatures = false;
92+
93+
/// Each connected component has a lazily-created requirement machine
94+
/// built from the requirement signatures of the protocols in this
95+
/// component.
96+
RequirementMachine *Machine = nullptr;
9197
};
9298

93-
/// The protocol dependency graph.
99+
/// We pre-load protocol dependencies here to avoid re-entrancy.
100+
llvm::DenseMap<const ProtocolDecl *, ArrayRef<ProtocolDecl *>> Dependencies;
101+
102+
/// Maps protocols to their connected components.
94103
llvm::DenseMap<const ProtocolDecl *, ProtocolNode> Protos;
95104

96105
/// Used by Tarjan's algorithm.
97106
unsigned NextComponentIndex = 0;
98107

108+
/// Prevents re-entrant calls into getProtocolComponentRec().
109+
bool ProtectProtocolComponentRec = false;
110+
99111
/// The connected components. Keys are the ComponentID fields of
100112
/// ProtocolNode.
101113
llvm::DenseMap<unsigned, ProtocolComponent> Components;
@@ -111,6 +123,7 @@ class RewriteContext final {
111123

112124
void getProtocolComponentRec(const ProtocolDecl *proto,
113125
SmallVectorImpl<const ProtocolDecl *> &stack);
126+
ProtocolComponent &getProtocolComponentImpl(const ProtocolDecl *proto);
114127

115128
public:
116129
/// Statistics.
@@ -182,6 +195,7 @@ class RewriteContext final {
182195
bool isRecursivelyConstructingRequirementMachine(CanGenericSignature sig);
183196

184197
ArrayRef<const ProtocolDecl *> getProtocolComponent(const ProtocolDecl *proto);
198+
RequirementMachine *getRequirementMachine(const ProtocolDecl *proto);
185199
bool isRecursivelyConstructingRequirementMachine(const ProtocolDecl *proto);
186200

187201
~RewriteContext();

0 commit comments

Comments
 (0)