Skip to content

Commit 3e47d94

Browse files
committed
RequirementMachine: Implement GenericSignature::getLocalRequirements() query
1 parent c2a275e commit 3e47d94

File tree

3 files changed

+165
-31
lines changed

3 files changed

+165
-31
lines changed

include/swift/AST/RequirementMachine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class RequirementMachine final {
5858
// Generic signature queries. Generally you shouldn't have to construct a
5959
// RequirementMachine instance; instead, call the corresponding methods on
6060
// GenericSignature, which lazily create a RequirementMachine for you.
61+
GenericSignature::LocalRequirements getLocalRequirements(Type depType,
62+
TypeArrayView<GenericTypeParamType> genericParams) const;
6163
bool requiresClass(Type depType) const;
6264
LayoutConstraint getLayoutConstraint(Type depType) const;
6365
bool requiresProtocol(Type depType, const ProtocolDecl *proto) const;

lib/AST/GenericSignature.cpp

Lines changed: 124 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -258,47 +258,134 @@ GenericSignature::LocalRequirements
258258
GenericSignatureImpl::getLocalRequirements(Type depType) const {
259259
assert(depType->isTypeParameter() && "Expected a type parameter here");
260260

261-
GenericSignature::LocalRequirements result;
261+
auto computeViaGSB = [&]() {
262+
GenericSignature::LocalRequirements result;
262263

263-
auto &builder = *getGenericSignatureBuilder();
264+
auto &builder = *getGenericSignatureBuilder();
264265

265-
auto resolved =
266-
builder.maybeResolveEquivalenceClass(
267-
depType,
268-
ArchetypeResolutionKind::CompleteWellFormed,
269-
/*wantExactPotentialArchetype=*/false);
270-
if (!resolved) {
271-
result.concreteType = ErrorType::get(depType);
272-
return result;
273-
}
266+
auto resolved =
267+
builder.maybeResolveEquivalenceClass(
268+
depType,
269+
ArchetypeResolutionKind::CompleteWellFormed,
270+
/*wantExactPotentialArchetype=*/false);
271+
if (!resolved) {
272+
result.concreteType = ErrorType::get(depType);
273+
return result;
274+
}
274275

275-
if (auto concreteType = resolved.getAsConcreteType()) {
276-
result.concreteType = concreteType;
277-
return result;
278-
}
276+
if (auto concreteType = resolved.getAsConcreteType()) {
277+
result.concreteType = concreteType;
278+
return result;
279+
}
279280

280-
auto *equivClass = resolved.getEquivalenceClass(builder);
281+
auto *equivClass = resolved.getEquivalenceClass(builder);
281282

282-
auto genericParams = getGenericParams();
283-
result.anchor = equivClass->getAnchor(builder, genericParams);
283+
auto genericParams = getGenericParams();
284+
result.anchor = equivClass->getAnchor(builder, genericParams);
285+
286+
if (equivClass->concreteType) {
287+
result.concreteType = equivClass->concreteType;
288+
return result;
289+
}
290+
291+
result.superclass = equivClass->superclass;
292+
293+
for (const auto &conforms : equivClass->conformsTo) {
294+
auto proto = conforms.first;
295+
296+
if (!equivClass->isConformanceSatisfiedBySuperclass(proto))
297+
result.protos.push_back(proto);
298+
}
299+
300+
result.layout = equivClass->layout;
284301

285-
if (equivClass->concreteType) {
286-
result.concreteType = equivClass->concreteType;
287302
return result;
288-
}
303+
};
289304

290-
result.superclass = equivClass->superclass;
305+
auto computeViaRQM = [&]() {
306+
auto *machine = getRequirementMachine();
307+
return machine->getLocalRequirements(depType, getGenericParams());
308+
};
291309

292-
for (const auto &conforms : equivClass->conformsTo) {
293-
auto proto = conforms.first;
310+
auto &ctx = getASTContext();
311+
if (ctx.LangOpts.EnableRequirementMachine) {
312+
auto rqmResult = computeViaRQM();
294313

295-
if (!equivClass->isConformanceSatisfiedBySuperclass(proto))
296-
result.protos.push_back(proto);
297-
}
314+
#ifndef NDEBUG
315+
auto gsbResult = computeViaGSB();
298316

299-
result.layout = equivClass->layout;
317+
auto typesEqual = [&](Type lhs, Type rhs) {
318+
if (!lhs || !rhs)
319+
return !lhs == !rhs;
320+
return lhs->isEqual(rhs);
321+
};
300322

301-
return result;
323+
auto compare = [&]() {
324+
// If the types are concrete, we don't care about the rest.
325+
if (gsbResult.concreteType || rqmResult.concreteType) {
326+
if (!typesEqual(gsbResult.concreteType, rqmResult.concreteType))
327+
return false;
328+
329+
return true;
330+
}
331+
332+
if (!typesEqual(gsbResult.anchor, rqmResult.anchor))
333+
return false;
334+
335+
if (gsbResult.layout != rqmResult.layout)
336+
return false;
337+
338+
auto lhsProtos = gsbResult.protos;
339+
ProtocolType::canonicalizeProtocols(lhsProtos);
340+
auto rhsProtos = rqmResult.protos;
341+
ProtocolType::canonicalizeProtocols(rhsProtos);
342+
343+
if (lhsProtos != rhsProtos)
344+
return false;
345+
346+
return true;
347+
};
348+
349+
auto dumpReqs = [&](const GenericSignature::LocalRequirements &reqs) {
350+
if (reqs.anchor) {
351+
llvm::errs() << "- Anchor: " << reqs.anchor << "\n";
352+
reqs.anchor.dump(llvm::errs());
353+
}
354+
if (reqs.concreteType) {
355+
llvm::errs() << "- Concrete type: " << reqs.concreteType << "\n";
356+
reqs.concreteType.dump(llvm::errs());
357+
}
358+
if (reqs.superclass) {
359+
llvm::errs() << "- Superclass: " << reqs.superclass << "\n";
360+
reqs.superclass.dump(llvm::errs());
361+
}
362+
if (reqs.layout) {
363+
llvm::errs() << "- Layout: " << reqs.layout << "\n";
364+
}
365+
for (const auto *proto : reqs.protos) {
366+
llvm::errs() << "- Conforms to: " << proto->getName() << "\n";
367+
}
368+
};
369+
370+
if (!compare()) {
371+
llvm::errs() << "RequirementMachine::getLocalRequirements() is broken\n";
372+
llvm::errs() << "Generic signature: " << GenericSignature(this) << "\n";
373+
llvm::errs() << "Dependent type: "; depType.dump(llvm::errs());
374+
llvm::errs() << "GenericSignatureBuilder says:\n";
375+
dumpReqs(gsbResult);
376+
llvm::errs() << "\n";
377+
llvm::errs() << "RequirementMachine says:\n";
378+
dumpReqs(rqmResult);
379+
llvm::errs() << "\n";
380+
getRequirementMachine()->dump(llvm::errs());
381+
abort();
382+
}
383+
#endif
384+
385+
return rqmResult;
386+
} else {
387+
return computeViaGSB();
388+
}
302389
}
303390

304391
ASTContext &GenericSignatureImpl::getASTContext() const {
@@ -1049,10 +1136,16 @@ GenericSignatureImpl::lookupNestedType(Type type, Identifier name) const {
10491136
llvm::errs() << "Generic signature: " << GenericSignature(this) << "\n";
10501137
llvm::errs() << "Dependent type: "; type.dump(llvm::errs());
10511138
llvm::errs() << "GenericSignatureBuilder says: ";
1052-
gsbResult->dumpRef(llvm::errs());
1139+
if (gsbResult)
1140+
gsbResult->dumpRef(llvm::errs());
1141+
else
1142+
llvm::errs() << "<nullptr>";
10531143
llvm::errs() << "\n";
10541144
llvm::errs() << "RequirementMachine says: ";
1055-
rqmResult->dumpRef(llvm::errs());
1145+
if (rqmResult)
1146+
rqmResult->dumpRef(llvm::errs());
1147+
else
1148+
llvm::errs() << "<nullptr>";
10561149
llvm::errs() << "\n";
10571150
getRequirementMachine()->dump(llvm::errs());
10581151
abort();

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,45 @@ void RequirementMachine::dump(llvm::raw_ostream &out) const {
445445
Impl->dump(out);
446446
}
447447

448+
/// Collects all requirements on a type parameter that are used to construct
449+
/// its ArchetypeType in a GenericEnvironment.
450+
GenericSignature::LocalRequirements
451+
RequirementMachine::getLocalRequirements(
452+
Type depType,
453+
TypeArrayView<GenericTypeParamType> genericParams) const {
454+
auto term = Impl->Context.getMutableTermForType(depType->getCanonicalType(),
455+
/*proto=*/nullptr);
456+
Impl->System.simplify(term);
457+
Impl->verify(term);
458+
459+
auto &protos = Impl->System.getProtocols();
460+
461+
GenericSignature::LocalRequirements result;
462+
result.anchor = Impl->Context.getTypeForTerm(term, genericParams, protos);
463+
464+
auto *equivClass = Impl->Map.lookUpEquivalenceClass(term);
465+
if (!equivClass)
466+
return result;
467+
468+
if (equivClass->isConcreteType()) {
469+
result.concreteType = equivClass->getConcreteType({}, protos,
470+
Impl->Context);
471+
return result;
472+
}
473+
474+
if (equivClass->hasSuperclassBound()) {
475+
result.superclass = equivClass->getSuperclassBound({}, protos,
476+
Impl->Context);
477+
}
478+
479+
for (const auto *proto : equivClass->getConformsTo())
480+
result.protos.push_back(const_cast<ProtocolDecl *>(proto));
481+
482+
result.layout = equivClass->getLayoutConstraint();
483+
484+
return result;
485+
}
486+
448487
bool RequirementMachine::requiresClass(Type depType) const {
449488
auto term = Impl->Context.getMutableTermForType(depType->getCanonicalType(),
450489
/*proto=*/nullptr);

0 commit comments

Comments
 (0)