Skip to content

Commit bf91cf9

Browse files
committed
Clean up WhereClauseOwner
1 parent 2b7dd20 commit bf91cf9

File tree

3 files changed

+28
-42
lines changed

3 files changed

+28
-42
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct PropertyWrapperBackingPropertyInfo;
3939
struct PropertyWrapperMutability;
4040
class RequirementRepr;
4141
class SpecializeAttr;
42+
class TrailingWhereClause;
4243
class TypeAliasDecl;
4344
struct TypeLoc;
4445
class Witness;
@@ -369,9 +370,10 @@ struct WhereClauseOwner {
369370

370371
/// The source of the where clause, which can be a generic parameter list
371372
/// or a declaration that can have a where clause.
372-
llvm::PointerUnion<GenericParamList *, Decl *, SpecializeAttr *> source;
373+
llvm::PointerUnion<GenericParamList *, TrailingWhereClause *, SpecializeAttr *> source;
373374

374-
WhereClauseOwner(Decl *decl);
375+
WhereClauseOwner(GenericContext *genCtx);
376+
WhereClauseOwner(AssociatedTypeDecl *atd);
375377

376378
WhereClauseOwner(DeclContext *dc, GenericParamList *genericParams)
377379
: dc(dc), source(genericParams) {}
@@ -382,13 +384,12 @@ struct WhereClauseOwner {
382384
SourceLoc getLoc() const;
383385

384386
friend hash_code hash_value(const WhereClauseOwner &owner) {
385-
return llvm::hash_combine(owner.dc, owner.source.getOpaqueValue());
387+
return llvm::hash_value(owner.source.getOpaqueValue());
386388
}
387389

388390
friend bool operator==(const WhereClauseOwner &lhs,
389391
const WhereClauseOwner &rhs) {
390-
return lhs.dc == rhs.dc &&
391-
lhs.source.getOpaqueValue() == rhs.source.getOpaqueValue();
392+
return lhs.source.getOpaqueValue() == rhs.source.getOpaqueValue();
392393
}
393394

394395
friend bool operator!=(const WhereClauseOwner &lhs,

lib/AST/TypeCheckRequests.cpp

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,20 @@ void RequirementSignatureRequest::cacheResult(ArrayRef<Requirement> value) const
337337
// Requirement computation.
338338
//----------------------------------------------------------------------------//
339339

340-
WhereClauseOwner::WhereClauseOwner(Decl *decl)
341-
: dc(decl->getInnermostDeclContext()), source(decl) { }
340+
WhereClauseOwner::WhereClauseOwner(GenericContext *genCtx): dc(genCtx) {
341+
if (const auto whereClause = genCtx->getTrailingWhereClause())
342+
source = whereClause;
343+
else
344+
source = genCtx->getGenericParams();
345+
}
346+
347+
WhereClauseOwner::WhereClauseOwner(AssociatedTypeDecl *atd)
348+
: dc(atd->getInnermostDeclContext()),
349+
source(atd->getTrailingWhereClause()) {}
342350

343351
SourceLoc WhereClauseOwner::getLoc() const {
344-
if (auto decl = source.dyn_cast<Decl *>())
345-
return decl->getLoc();
352+
if (auto where = source.dyn_cast<TrailingWhereClause *>())
353+
return where->getWhereLoc();
346354

347355
if (auto attr = source.dyn_cast<SpecializeAttr *>())
348356
return attr->getLocation();
@@ -352,8 +360,8 @@ SourceLoc WhereClauseOwner::getLoc() const {
352360

353361
void swift::simple_display(llvm::raw_ostream &out,
354362
const WhereClauseOwner &owner) {
355-
if (auto decl = owner.source.dyn_cast<Decl *>()) {
356-
simple_display(out, decl);
363+
if (auto where = owner.source.dyn_cast<TrailingWhereClause *>()) {
364+
simple_display(out, owner.dc->getAsDecl());
357365
} else if (owner.source.is<SpecializeAttr *>()) {
358366
out << "@_specialize";
359367
} else {
@@ -375,36 +383,13 @@ void RequirementRequest::noteCycleStep(DiagnosticEngine &diags) const {
375383
}
376384

377385
MutableArrayRef<RequirementRepr> WhereClauseOwner::getRequirements() const {
378-
if (auto genericParams = source.dyn_cast<GenericParamList *>()) {
386+
if (const auto genericParams = source.dyn_cast<GenericParamList *>()) {
379387
return genericParams->getRequirements();
380-
}
381-
382-
if (auto attr = source.dyn_cast<SpecializeAttr *>()) {
388+
} else if (const auto attr = source.dyn_cast<SpecializeAttr *>()) {
383389
if (auto whereClause = attr->getTrailingWhereClause())
384390
return whereClause->getRequirements();
385-
386-
return { };
387-
}
388-
389-
auto decl = source.dyn_cast<Decl *>();
390-
if (!decl)
391-
return { };
392-
393-
if (auto proto = dyn_cast<ProtocolDecl>(decl)) {
394-
if (auto whereClause = proto->getTrailingWhereClause())
395-
return whereClause->getRequirements();
396-
397-
return { };
398-
}
399-
400-
if (auto assocType = dyn_cast<AssociatedTypeDecl>(decl)) {
401-
if (auto whereClause = assocType->getTrailingWhereClause())
402-
return whereClause->getRequirements();
403-
}
404-
405-
if (auto genericContext = decl->getAsGenericContext()) {
406-
if (auto genericParams = genericContext->getGenericParams())
407-
return genericParams->getRequirements();
391+
} else if (const auto whereClause = source.get<TrailingWhereClause *>()) {
392+
return whereClause->getRequirements();
408393
}
409394

410395
return { };

lib/Sema/TypeCheckDecl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,13 +2218,13 @@ SelfAccessKindRequest::evaluate(Evaluator &evaluator, FuncDecl *FD) const {
22182218
return SelfAccessKind::NonMutating;
22192219
}
22202220

2221-
/// Check the requirements in the where clause of the given \c source
2221+
/// Check the requirements in the where clause of the given \c atd
22222222
/// to ensure that they don't introduce additional 'Self' requirements.
22232223
static void checkProtocolSelfRequirements(ProtocolDecl *proto,
2224-
TypeDecl *source) {
2225-
WhereClauseOwner(source).visitRequirements(
2224+
AssociatedTypeDecl *atd) {
2225+
WhereClauseOwner(atd).visitRequirements(
22262226
TypeResolutionStage::Interface,
2227-
[&](const Requirement &req, RequirementRepr *reqRepr) {
2227+
[proto](const Requirement &req, RequirementRepr *reqRepr) {
22282228
switch (req.getKind()) {
22292229
case RequirementKind::Conformance:
22302230
case RequirementKind::Layout:

0 commit comments

Comments
 (0)