Skip to content

Commit ac3dcb1

Browse files
authored
Merge pull request swiftlang#27067 from CodaFi/requesting-a-clean-sweep
[Evaluator Ergonomics] Shuffle Evaluator Infrastructure Out Of Requests
2 parents 6b158e0 + bea0a6a commit ac3dcb1

File tree

9 files changed

+90
-133
lines changed

9 files changed

+90
-133
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,10 @@ struct WhereClauseOwner {
370370
WhereClauseOwner(Decl *decl);
371371

372372
WhereClauseOwner(DeclContext *dc, GenericParamList *genericParams)
373-
: dc(dc), source(genericParams) { }
373+
: dc(dc), source(genericParams) {}
374374

375375
WhereClauseOwner(DeclContext *dc, SpecializeAttr *attr)
376-
: dc(dc), source(attr) { }
376+
: dc(dc), source(attr) {}
377377

378378
SourceLoc getLoc() const;
379379

@@ -392,6 +392,19 @@ struct WhereClauseOwner {
392392
const WhereClauseOwner &rhs) {
393393
return !(lhs == rhs);
394394
}
395+
396+
public:
397+
/// Retrieve the array of requirements.
398+
MutableArrayRef<RequirementRepr> getRequirements() const;
399+
400+
/// Visit each of the requirements,
401+
///
402+
/// \returns true after short-circuiting if the callback returned \c true
403+
/// for any of the requirements.
404+
bool
405+
visitRequirements(TypeResolutionStage stage,
406+
llvm::function_ref<bool(Requirement, RequirementRepr *)>
407+
callback) const &&;
395408
};
396409

397410
void simple_display(llvm::raw_ostream &out, const WhereClauseOwner &owner);
@@ -405,17 +418,6 @@ class RequirementRequest :
405418
public:
406419
using SimpleRequest::SimpleRequest;
407420

408-
/// Retrieve the array of requirements from the given owner.
409-
static MutableArrayRef<RequirementRepr> getRequirements(WhereClauseOwner);
410-
411-
/// Visit each of the requirements in the given owner,
412-
///
413-
/// \returns true after short-circuiting if the callback returned \c true
414-
/// for any of the requirements.
415-
static bool visitRequirements(
416-
WhereClauseOwner, TypeResolutionStage stage,
417-
llvm::function_ref<bool(Requirement, RequirementRepr*)> callback);
418-
419421
private:
420422
friend SimpleRequest;
421423

@@ -501,20 +503,6 @@ class DefaultTypeRequest
501503
bool isCached() const { return true; }
502504
Optional<Type> getCachedResult() const;
503505
void cacheResult(Type value) const;
504-
505-
private:
506-
KnownProtocolKind getKnownProtocolKind() const {
507-
return std::get<0>(getStorage());
508-
}
509-
const DeclContext *getDeclContext() const {
510-
return std::get<1>(getStorage());
511-
}
512-
513-
static const char *getTypeName(KnownProtocolKind);
514-
static bool getPerformLocalLookup(KnownProtocolKind);
515-
TypeChecker &getTypeChecker() const;
516-
SourceFile *getSourceFile() const;
517-
Type &getCache() const;
518506
};
519507

520508
/// Retrieve information about a property wrapper type.

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4114,7 +4114,7 @@ ConstraintResult GenericSignatureBuilder::expandConformanceRequirement(
41144114
}
41154115

41164116
// Add any requirements in the where clause on the protocol.
4117-
RequirementRequest::visitRequirements(proto, TypeResolutionStage::Structural,
4117+
WhereClauseOwner(proto).visitRequirements(TypeResolutionStage::Structural,
41184118
[&](const Requirement &req, RequirementRepr *reqRepr) {
41194119
// If we're only looking at same-type constraints, skip everything else.
41204120
if (onlySameTypeConstraints &&
@@ -4240,8 +4240,8 @@ ConstraintResult GenericSignatureBuilder::expandConformanceRequirement(
42404240
}
42414241

42424242
// Add requirements from this associated type's where clause.
4243-
RequirementRequest::visitRequirements(assocTypeDecl,
4244-
TypeResolutionStage::Structural,
4243+
WhereClauseOwner(assocTypeDecl).visitRequirements(
4244+
TypeResolutionStage::Structural,
42454245
[&](const Requirement &req, RequirementRepr *reqRepr) {
42464246
// If we're only looking at same-type constraints, skip everything else.
42474247
if (onlySameTypeConstraints &&
@@ -7714,18 +7714,18 @@ InferredGenericSignatureRequest::evaluate(
77147714

77157715
// Add the requirements clause to the builder.
77167716

7717-
WhereClauseOwner owner(lookupDC, genericParams);
77187717
using FloatingRequirementSource =
77197718
GenericSignatureBuilder::FloatingRequirementSource;
7720-
RequirementRequest::visitRequirements(owner, TypeResolutionStage::Structural,
7719+
WhereClauseOwner(lookupDC, genericParams).visitRequirements(
7720+
TypeResolutionStage::Structural,
77217721
[&](const Requirement &req, RequirementRepr *reqRepr) {
77227722
auto source = FloatingRequirementSource::forExplicit(reqRepr);
77237723

77247724
// If we're extending a protocol and adding a redundant requirement,
77257725
// for example, `extension Foo where Self: Foo`, then emit a
77267726
// diagnostic.
77277727

7728-
if (auto decl = owner.dc->getAsDecl()) {
7728+
if (auto decl = lookupDC->getAsDecl()) {
77297729
if (auto extDecl = dyn_cast<ExtensionDecl>(decl)) {
77307730
auto extType = extDecl->getDeclaredInterfaceType();
77317731
auto extSelfType = extDecl->getSelfInterfaceType();

lib/AST/TypeCheckRequests.cpp

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -364,20 +364,19 @@ SourceLoc RequirementRequest::getNearestLoc() const {
364364
return owner.getLoc();
365365
}
366366

367-
MutableArrayRef<RequirementRepr>
368-
RequirementRequest::getRequirements(WhereClauseOwner owner) {
369-
if (auto genericParams = owner.source.dyn_cast<GenericParamList *>()) {
367+
MutableArrayRef<RequirementRepr> WhereClauseOwner::getRequirements() const {
368+
if (auto genericParams = source.dyn_cast<GenericParamList *>()) {
370369
return genericParams->getRequirements();
371370
}
372371

373-
if (auto attr = owner.source.dyn_cast<SpecializeAttr *>()) {
372+
if (auto attr = source.dyn_cast<SpecializeAttr *>()) {
374373
if (auto whereClause = attr->getTrailingWhereClause())
375374
return whereClause->getRequirements();
376375

377376
return { };
378377
}
379378

380-
auto decl = owner.source.dyn_cast<Decl *>();
379+
auto decl = source.dyn_cast<Decl *>();
381380
if (!decl)
382381
return { };
383382

@@ -401,14 +400,15 @@ RequirementRequest::getRequirements(WhereClauseOwner owner) {
401400
return { };
402401
}
403402

404-
bool RequirementRequest::visitRequirements(
405-
WhereClauseOwner owner, TypeResolutionStage stage,
406-
llvm::function_ref<bool(Requirement, RequirementRepr*)> callback) {
407-
auto &evaluator = owner.dc->getASTContext().evaluator;
408-
auto requirements = getRequirements(owner);
403+
bool WhereClauseOwner::visitRequirements(
404+
TypeResolutionStage stage,
405+
llvm::function_ref<bool(Requirement, RequirementRepr *)> callback)
406+
const && {
407+
auto &evaluator = dc->getASTContext().evaluator;
408+
auto requirements = getRequirements();
409409
for (unsigned index : indices(requirements)) {
410410
// Resolve to a requirement.
411-
auto req = evaluator(RequirementRequest{owner, index, stage});
411+
auto req = evaluator(RequirementRequest{*this, index, stage});
412412
if (req) {
413413
// Invoke the callback. If it returns true, we're done.
414414
if (callback(*req, &requirements[index]))
@@ -417,10 +417,10 @@ bool RequirementRequest::visitRequirements(
417417
continue;
418418
}
419419

420-
llvm::handleAllErrors(req.takeError(),
421-
[](const CyclicalRequestError<RequirementRequest> &E) {
422-
// cycle detected
423-
});
420+
llvm::handleAllErrors(
421+
req.takeError(), [](const CyclicalRequestError<RequirementRequest> &E) {
422+
// cycle detected
423+
});
424424
}
425425

426426
return false;
@@ -429,7 +429,7 @@ bool RequirementRequest::visitRequirements(
429429
RequirementRepr &RequirementRequest::getRequirement() const {
430430
auto owner = std::get<0>(getStorage());
431431
auto index = std::get<1>(getStorage());
432-
return getRequirements(owner)[index];
432+
return owner.getRequirements()[index];
433433
}
434434

435435
bool RequirementRequest::isCached() const {
@@ -503,49 +503,20 @@ void swift::simple_display(llvm::raw_ostream &out,
503503
// DefaultTypeRequest caching.
504504
//----------------------------------------------------------------------------//
505505

506-
SourceFile *DefaultTypeRequest::getSourceFile() const {
507-
return getDeclContext()->getParentSourceFile();
508-
}
509-
510-
Type &DefaultTypeRequest::getCache() const {
511-
return getDeclContext()->getASTContext().getDefaultTypeRequestCache(
512-
getSourceFile(), getKnownProtocolKind());
513-
}
514-
515506
Optional<Type> DefaultTypeRequest::getCachedResult() const {
516-
auto const &cachedType = getCache();
507+
auto *DC = std::get<1>(getStorage());
508+
auto knownProtocolKind = std::get<0>(getStorage());
509+
const auto &cachedType = DC->getASTContext().getDefaultTypeRequestCache(
510+
DC->getParentSourceFile(), knownProtocolKind);
517511
return cachedType ? Optional<Type>(cachedType) : None;
518512
}
519513

520-
void DefaultTypeRequest::cacheResult(Type value) const { getCache() = value; }
521-
522-
const char *
523-
DefaultTypeRequest::getTypeName(const KnownProtocolKind knownProtocolKind) {
524-
switch (knownProtocolKind) {
525-
526-
// clang-format off
527-
# define EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME(Id, Name, typeName, performLocalLookup) \
528-
case KnownProtocolKind::Id: return typeName;
529-
# include "swift/AST/KnownProtocols.def"
530-
# undef EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME
531-
//clang-format on
532-
533-
default: return nullptr;
534-
}
535-
}
536-
537-
bool DefaultTypeRequest::getPerformLocalLookup(const KnownProtocolKind knownProtocolKind) {
538-
switch (knownProtocolKind) {
539-
540-
// clang-format off
541-
# define EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME(Id, Name, typeName, performLocalLookup) \
542-
case KnownProtocolKind::Id: return performLocalLookup;
543-
# include "swift/AST/KnownProtocols.def"
544-
# undef EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME
545-
//clang-format on
546-
547-
default: return false;
548-
}
514+
void DefaultTypeRequest::cacheResult(Type value) const {
515+
auto *DC = std::get<1>(getStorage());
516+
auto knownProtocolKind = std::get<0>(getStorage());
517+
auto &cacheEntry = DC->getASTContext().getDefaultTypeRequestCache(
518+
DC->getParentSourceFile(), knownProtocolKind);
519+
cacheEntry = value;
549520
}
550521

551522
bool PropertyWrapperTypeInfoRequest::isCached() const {

lib/Sema/TypeCheckAccess.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ enum class DowngradeToWarning: bool {
4444
/// Calls \p callback for each type in each requirement provided by
4545
/// \p source.
4646
static void forAllRequirementTypes(
47-
WhereClauseOwner source,
47+
WhereClauseOwner &&source,
4848
llvm::function_ref<void(Type, TypeRepr *)> callback) {
49-
RequirementRequest::visitRequirements(
50-
source, TypeResolutionStage::Interface,
51-
[&](const Requirement &req, RequirementRepr* reqRepr) {
49+
std::move(source).visitRequirements(TypeResolutionStage::Interface,
50+
[&](const Requirement &req, RequirementRepr *reqRepr) {
5251
switch (req.getKind()) {
5352
case RequirementKind::Conformance:
5453
case RequirementKind::SameType:
@@ -95,11 +94,11 @@ class AccessControlCheckerBase {
9594
}
9695

9796
void checkRequirementAccess(
98-
WhereClauseOwner source,
97+
WhereClauseOwner &&source,
9998
AccessScope accessScope,
10099
const DeclContext *useDC,
101100
llvm::function_ref<CheckTypeAccessCallback> diagnose) {
102-
forAllRequirementTypes(source, [&](Type type, TypeRepr *typeRepr) {
101+
forAllRequirementTypes(std::move(source), [&](Type type, TypeRepr *typeRepr) {
103102
checkTypeAccessImpl(type, typeRepr, accessScope, useDC,
104103
/*mayBeInferred*/false, diagnose);
105104
});

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,8 +1852,7 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
18521852
SmallPtrSet<TypeBase *, 4> constrainedGenericParams;
18531853

18541854
// Go over the set of requirements, adding them to the builder.
1855-
RequirementRequest::visitRequirements(
1856-
WhereClauseOwner(FD, attr), TypeResolutionStage::Interface,
1855+
WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface,
18571856
[&](const Requirement &req, RequirementRepr *reqRepr) {
18581857
// Collect all of the generic parameters used by these types.
18591858
switch (req.getKind()) {

lib/Sema/TypeCheckDecl.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,7 @@ static void checkInheritanceClause(
485485
/// Check the inheritance clauses generic parameters along with any
486486
/// requirements stored within the generic parameter list.
487487
static void checkGenericParams(GenericParamList *genericParams,
488-
DeclContext *owningDC,
489-
TypeChecker &tc) {
488+
DeclContext *owningDC, TypeChecker &tc) {
490489
if (!genericParams)
491490
return;
492491

@@ -496,12 +495,9 @@ static void checkGenericParams(GenericParamList *genericParams,
496495
}
497496

498497
// Force visitation of each of the requirements here.
499-
RequirementRequest::visitRequirements(WhereClauseOwner(owningDC,
500-
genericParams),
501-
TypeResolutionStage::Interface,
502-
[](Requirement, RequirementRepr *) {
503-
return false;
504-
});
498+
WhereClauseOwner(owningDC, genericParams)
499+
.visitRequirements(TypeResolutionStage::Interface,
500+
[](Requirement, RequirementRepr *) { return false; });
505501
}
506502

507503
/// Retrieve the set of protocols the given protocol inherits.
@@ -2011,7 +2007,8 @@ SelfAccessKindRequest::evaluate(Evaluator &evaluator, FuncDecl *FD) const {
20112007
/// to ensure that they don't introduce additional 'Self' requirements.
20122008
static void checkProtocolSelfRequirements(ProtocolDecl *proto,
20132009
TypeDecl *source) {
2014-
RequirementRequest::visitRequirements(source, TypeResolutionStage::Interface,
2010+
WhereClauseOwner(source).visitRequirements(
2011+
TypeResolutionStage::Interface,
20152012
[&](const Requirement &req, RequirementRepr *reqRepr) {
20162013
switch (req.getKind()) {
20172014
case RequirementKind::Conformance:

lib/Sema/TypeCheckExpr.cpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -663,39 +663,54 @@ static Optional<KnownProtocolKind>
663663
getKnownProtocolKindIfAny(const ProtocolDecl *protocol) {
664664
TypeChecker &tc = TypeChecker::createForContext(protocol->getASTContext());
665665

666-
// clang-format off
667-
#define EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME(Id, _, __, ___) \
668-
if (protocol == tc.getProtocol(SourceLoc(), KnownProtocolKind::Id)) \
669-
return KnownProtocolKind::Id;
670-
#include "swift/AST/KnownProtocols.def"
671-
#undef EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME
672-
// clang-format on
666+
#define EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME(Id, _, __, ___) \
667+
if (protocol == tc.getProtocol(SourceLoc(), KnownProtocolKind::Id)) \
668+
return KnownProtocolKind::Id;
669+
#include "swift/AST/KnownProtocols.def"
670+
#undef EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME
673671

674672
return None;
675673
}
676674

677675
Type TypeChecker::getDefaultType(ProtocolDecl *protocol, DeclContext *dc) {
678676
if (auto knownProtocolKindIfAny = getKnownProtocolKindIfAny(protocol)) {
679-
Type t = evaluateOrDefault(
677+
return evaluateOrDefault(
680678
Context.evaluator,
681679
DefaultTypeRequest{knownProtocolKindIfAny.getValue(), dc}, nullptr);
682-
return t;
683680
}
684-
return nullptr;
681+
return Type();
682+
}
683+
684+
static std::pair<const char *, bool> lookupDefaultTypeInfoForKnownProtocol(
685+
const KnownProtocolKind knownProtocolKind) {
686+
switch (knownProtocolKind) {
687+
#define EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME(Id, Name, typeName, \
688+
performLocalLookup) \
689+
case KnownProtocolKind::Id: \
690+
return {typeName, performLocalLookup};
691+
#include "swift/AST/KnownProtocols.def"
692+
#undef EXPRESSIBLE_BY_LITERAL_PROTOCOL_WITH_NAME
693+
default:
694+
return {nullptr, false};
695+
}
685696
}
686697

687698
llvm::Expected<Type>
688699
swift::DefaultTypeRequest::evaluate(Evaluator &evaluator,
689700
KnownProtocolKind knownProtocolKind,
690701
const DeclContext *dc) const {
691-
const char *const name = getTypeName(knownProtocolKind);
702+
const char *name;
703+
bool performLocalLookup;
704+
std::tie(name, performLocalLookup) =
705+
lookupDefaultTypeInfoForKnownProtocol(knownProtocolKind);
692706
if (!name)
693707
return nullptr;
694708

695-
TypeChecker &tc = getTypeChecker();
709+
// FIXME: Creating a whole type checker just to do lookup is unnecessary.
710+
TypeChecker &tc = TypeChecker::createForContext(dc->getASTContext());
696711

697712
Type type;
698-
if (getPerformLocalLookup(knownProtocolKind))
713+
if (performLocalLookup)
699714
type = lookupDefaultLiteralType(tc, dc, name);
700715

701716
if (!type)
@@ -710,10 +725,6 @@ swift::DefaultTypeRequest::evaluate(Evaluator &evaluator,
710725
return type;
711726
}
712727

713-
TypeChecker &DefaultTypeRequest::getTypeChecker() const {
714-
return TypeChecker::createForContext(getDeclContext()->getASTContext());
715-
}
716-
717728
Expr *TypeChecker::foldSequence(SequenceExpr *expr, DeclContext *dc) {
718729
ArrayRef<Expr*> Elts = expr->getElements();
719730
assert(Elts.size() > 1 && "inadequate number of elements in sequence");

0 commit comments

Comments
 (0)