Skip to content

Commit 0b10948

Browse files
committed
AST: Generalize ASTContext::getOverrideGenericSignature() to work with protocols
We want to avoid feeding invalid type parameters into the Requirement Machine when checking if a protocol requirement overrides another protocol requirement in an inherited protocol. In order to do that we need to make sure the potential override has a compatible generic signature before we attempt substitution, just like we already do for classes. To do that, we need a way to 'rewrite' a generic signature for the base requirement into a generic signature that can be compared with the derived requirement, just as we do with class methods. For protocols this is a little easier than with class methods, since protocols only ever have a single 'Self' generic parameter. So just add the 'Self : DerivedProto' requirement to the base requirement's signature and rebuild.
1 parent 910c2a4 commit 0b10948

File tree

1 file changed

+56
-46
lines changed

1 file changed

+56
-46
lines changed

lib/AST/ASTContext.cpp

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ using AssociativityCacheType =
101101
struct OverrideSignatureKey {
102102
GenericSignature baseMethodSig;
103103
GenericSignature derivedMethodSig;
104-
Decl *subclassDecl;
104+
NominalTypeDecl *derivedNominal;
105105

106106
OverrideSignatureKey(GenericSignature baseMethodSignature,
107107
GenericSignature derivedMethodSignature,
108-
Decl *subclassDecl)
108+
NominalTypeDecl *derivedNominal)
109109
: baseMethodSig(baseMethodSignature),
110110
derivedMethodSig(derivedMethodSignature),
111-
subclassDecl(subclassDecl) {}
111+
derivedNominal(derivedNominal) {}
112112
};
113113

114114
namespace llvm {
@@ -120,27 +120,27 @@ template <> struct DenseMapInfo<OverrideSignatureKey> {
120120
const OverrideSignatureKey rhs) {
121121
return lhs.baseMethodSig.getPointer() == rhs.baseMethodSig.getPointer() &&
122122
lhs.derivedMethodSig.getPointer() == rhs.derivedMethodSig.getPointer() &&
123-
lhs.subclassDecl == rhs.subclassDecl;
123+
lhs.derivedNominal == rhs.derivedNominal;
124124
}
125125

126126
static inline OverrideSignatureKey getEmptyKey() {
127127
return OverrideSignatureKey(DenseMapInfo<GenericSignature>::getEmptyKey(),
128128
DenseMapInfo<GenericSignature>::getEmptyKey(),
129-
DenseMapInfo<Decl *>::getEmptyKey());
129+
DenseMapInfo<NominalTypeDecl *>::getEmptyKey());
130130
}
131131

132132
static inline OverrideSignatureKey getTombstoneKey() {
133133
return OverrideSignatureKey(
134134
DenseMapInfo<GenericSignature>::getTombstoneKey(),
135135
DenseMapInfo<GenericSignature>::getTombstoneKey(),
136-
DenseMapInfo<Decl *>::getTombstoneKey());
136+
DenseMapInfo<NominalTypeDecl *>::getTombstoneKey());
137137
}
138138

139139
static unsigned getHashValue(const OverrideSignatureKey &Val) {
140140
return hash_combine(
141141
DenseMapInfo<GenericSignature>::getHashValue(Val.baseMethodSig),
142142
DenseMapInfo<GenericSignature>::getHashValue(Val.derivedMethodSig),
143-
DenseMapInfo<Decl *>::getHashValue(Val.subclassDecl));
143+
DenseMapInfo<NominalTypeDecl *>::getHashValue(Val.derivedNominal));
144144
}
145145
};
146146
} // namespace llvm
@@ -5214,11 +5214,11 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52145214
assert(isa<AbstractFunctionDecl>(base) || isa<SubscriptDecl>(base));
52155215
assert(isa<AbstractFunctionDecl>(derived) || isa<SubscriptDecl>(derived));
52165216

5217-
const auto baseClass = base->getDeclContext()->getSelfClassDecl();
5218-
const auto derivedClass = derived->getDeclContext()->getSelfClassDecl();
5217+
const auto baseNominal = base->getDeclContext()->getSelfNominalTypeDecl();
5218+
const auto derivedNominal = derived->getDeclContext()->getSelfNominalTypeDecl();
52195219

5220-
assert(baseClass != nullptr);
5221-
assert(derivedClass != nullptr);
5220+
assert(baseNominal != nullptr);
5221+
assert(derivedNominal != nullptr);
52225222

52235223
const auto baseGenericSig =
52245224
base->getAsGenericContext()->getGenericSignature();
@@ -5228,10 +5228,6 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52285228
if (base == derived)
52295229
return derivedGenericSig;
52305230

5231-
const auto derivedSuperclass = derivedClass->getSuperclass();
5232-
if (derivedSuperclass.isNull())
5233-
return nullptr;
5234-
52355231
if (derivedGenericSig.isNull())
52365232
return nullptr;
52375233

@@ -5240,21 +5236,14 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52405236

52415237
auto key = OverrideSignatureKey(baseGenericSig,
52425238
derivedGenericSig,
5243-
derivedClass);
5239+
derivedNominal);
52445240

52455241
if (getImpl().overrideSigCache.find(key) !=
52465242
getImpl().overrideSigCache.end()) {
52475243
return getImpl().overrideSigCache.lookup(key);
52485244
}
52495245

5250-
const auto derivedClassSig = derivedClass->getGenericSignature();
5251-
5252-
unsigned derivedDepth = 0;
5253-
unsigned baseDepth = 0;
5254-
if (derivedClassSig)
5255-
derivedDepth = derivedClassSig.getGenericParams().back()->getDepth() + 1;
5256-
if (const auto baseClassSig = baseClass->getGenericSignature())
5257-
baseDepth = baseClassSig.getGenericParams().back()->getDepth() + 1;
5246+
const auto derivedNominalSig = derivedNominal->getGenericSignature();
52585247

52595248
SmallVector<GenericTypeParamType *, 2> addedGenericParams;
52605249
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
@@ -5264,38 +5253,59 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52645253
}
52655254
}
52665255

5267-
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
5268-
derivedClass->getModuleContext(), baseClass);
5256+
SmallVector<Requirement, 2> addedRequirements;
52695257

5270-
auto substFn = [&](SubstitutableType *type) -> Type {
5271-
auto *gp = cast<GenericTypeParamType>(type);
5258+
if (isa<ProtocolDecl>(baseNominal)) {
5259+
assert(isa<ProtocolDecl>(derivedNominal));
52725260

5273-
if (gp->getDepth() < baseDepth) {
5274-
return Type(gp).subst(subMap);
5261+
for (auto reqt : baseGenericSig.getRequirements()) {
5262+
addedRequirements.push_back(reqt);
52755263
}
5264+
} else {
5265+
const auto derivedSuperclass = cast<ClassDecl>(derivedNominal)
5266+
->getSuperclass();
5267+
if (derivedSuperclass.isNull())
5268+
return nullptr;
52765269

5277-
return CanGenericTypeParamType::get(
5278-
gp->isTypeSequence(), gp->getDepth() - baseDepth + derivedDepth,
5279-
gp->getIndex(), *this);
5280-
};
5270+
unsigned derivedDepth = 0;
5271+
unsigned baseDepth = 0;
5272+
if (derivedNominalSig)
5273+
derivedDepth = derivedNominalSig.getGenericParams().back()->getDepth() + 1;
5274+
if (const auto baseNominalSig = baseNominal->getGenericSignature())
5275+
baseDepth = baseNominalSig.getGenericParams().back()->getDepth() + 1;
52815276

5282-
auto lookupConformanceFn =
5283-
[&](CanType depTy, Type substTy,
5284-
ProtocolDecl *proto) -> ProtocolConformanceRef {
5285-
if (auto conf = subMap.lookupConformance(depTy, proto))
5286-
return conf;
5277+
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
5278+
derivedNominal->getModuleContext(), baseNominal);
52875279

5288-
return ProtocolConformanceRef(proto);
5289-
};
5280+
auto substFn = [&](SubstitutableType *type) -> Type {
5281+
auto *gp = cast<GenericTypeParamType>(type);
52905282

5291-
SmallVector<Requirement, 2> addedRequirements;
5292-
for (auto reqt : baseGenericSig.getRequirements()) {
5293-
if (auto substReqt = reqt.subst(substFn, lookupConformanceFn)) {
5294-
addedRequirements.push_back(*substReqt);
5283+
if (gp->getDepth() < baseDepth) {
5284+
return Type(gp).subst(subMap);
5285+
}
5286+
5287+
return CanGenericTypeParamType::get(
5288+
gp->isTypeSequence(), gp->getDepth() - baseDepth + derivedDepth,
5289+
gp->getIndex(), *this);
5290+
};
5291+
5292+
auto lookupConformanceFn =
5293+
[&](CanType depTy, Type substTy,
5294+
ProtocolDecl *proto) -> ProtocolConformanceRef {
5295+
if (auto conf = subMap.lookupConformance(depTy, proto))
5296+
return conf;
5297+
5298+
return ProtocolConformanceRef(proto);
5299+
};
5300+
5301+
for (auto reqt : baseGenericSig.getRequirements()) {
5302+
if (auto substReqt = reqt.subst(substFn, lookupConformanceFn)) {
5303+
addedRequirements.push_back(*substReqt);
5304+
}
52955305
}
52965306
}
52975307

5298-
auto genericSig = buildGenericSignature(*this, derivedClassSig,
5308+
auto genericSig = buildGenericSignature(*this, derivedNominalSig,
52995309
std::move(addedGenericParams),
53005310
std::move(addedRequirements));
53015311
getImpl().overrideSigCache.insert(std::make_pair(key, genericSig));

0 commit comments

Comments
 (0)