Skip to content

Commit ce93261

Browse files
authored
Merge pull request #60310 from slavapestov/clean-up-override-substitutions
Clean up calculation of override substitutions
2 parents 1f6f969 + affc39a commit ce93261

15 files changed

+267
-299
lines changed

include/swift/AST/ASTContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,12 @@ class ASTContext final {
13441344
GenericSignature getOverrideGenericSignature(const ValueDecl *base,
13451345
const ValueDecl *derived);
13461346

1347+
GenericSignature
1348+
getOverrideGenericSignature(const NominalTypeDecl *baseNominal,
1349+
const NominalTypeDecl *derivedNominal,
1350+
GenericSignature baseGenericSig,
1351+
const GenericParamList *derivedParams);
1352+
13471353
enum class OverrideGenericSignatureReqCheck {
13481354
/// Base method's generic requirements are satisfied by derived method
13491355
BaseReqSatisfiedByDerived,

include/swift/AST/SubstitutionMap.h

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace llvm {
3333
namespace swift {
3434

3535
class GenericEnvironment;
36+
class GenericParamList;
3637
class SubstitutableType;
3738
typedef CanTypeWrapper<GenericTypeParamType> CanGenericTypeParamType;
3839

@@ -199,17 +200,15 @@ class SubstitutionMap {
199200
/// written in terms of the generic signature of 'baseDecl'.
200201
static SubstitutionMap
201202
getOverrideSubstitutions(const ValueDecl *baseDecl,
202-
const ValueDecl *derivedDecl,
203-
Optional<SubstitutionMap> derivedSubs);
203+
const ValueDecl *derivedDecl);
204204

205205
/// Variant of the above for when we have the generic signatures but not
206206
/// the decls for 'derived' and 'base'.
207207
static SubstitutionMap
208-
getOverrideSubstitutions(const ClassDecl *baseClass,
209-
const ClassDecl *derivedClass,
208+
getOverrideSubstitutions(const NominalTypeDecl *baseNominal,
209+
const NominalTypeDecl *derivedNominal,
210210
GenericSignature baseSig,
211-
GenericSignature derivedSig,
212-
Optional<SubstitutionMap> derivedSubs);
211+
const GenericParamList *derivedParams);
213212

214213
/// Combine two substitution maps as follows.
215214
///
@@ -313,6 +312,39 @@ class LookUpConformanceInSubstitutionMap {
313312
ProtocolDecl *conformedProtocol) const;
314313
};
315314

315+
struct OverrideSubsInfo {
316+
ASTContext &Ctx;
317+
unsigned BaseDepth;
318+
unsigned OrigDepth;
319+
SubstitutionMap BaseSubMap;
320+
const GenericParamList *DerivedParams;
321+
322+
OverrideSubsInfo(const NominalTypeDecl *baseNominal,
323+
const NominalTypeDecl *derivedNominal,
324+
GenericSignature baseSig,
325+
const GenericParamList *derivedParams);
326+
};
327+
328+
struct QueryOverrideSubs {
329+
OverrideSubsInfo info;
330+
331+
explicit QueryOverrideSubs(const OverrideSubsInfo &info)
332+
: info(info) {}
333+
334+
Type operator()(SubstitutableType *type) const;
335+
};
336+
337+
struct LookUpConformanceInOverrideSubs {
338+
OverrideSubsInfo info;
339+
340+
explicit LookUpConformanceInOverrideSubs(const OverrideSubsInfo &info)
341+
: info(info) {}
342+
343+
ProtocolConformanceRef operator()(CanType type,
344+
Type substType,
345+
ProtocolDecl *proto) const;
346+
};
347+
316348
} // end namespace swift
317349

318350
namespace llvm {

lib/AST/ASTContext.cpp

Lines changed: 52 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,18 @@ using AssociativityCacheType =
103103

104104
struct OverrideSignatureKey {
105105
GenericSignature baseMethodSig;
106-
GenericSignature derivedMethodSig;
107-
NominalTypeDecl *derivedNominal;
108-
109-
OverrideSignatureKey(GenericSignature baseMethodSignature,
110-
GenericSignature derivedMethodSignature,
111-
NominalTypeDecl *derivedNominal)
112-
: baseMethodSig(baseMethodSignature),
113-
derivedMethodSig(derivedMethodSignature),
114-
derivedNominal(derivedNominal) {}
106+
const NominalTypeDecl *baseNominal;
107+
const NominalTypeDecl *derivedNominal;
108+
const GenericParamList *derivedParams;
109+
110+
OverrideSignatureKey(GenericSignature baseMethodSig,
111+
const NominalTypeDecl *baseNominal,
112+
const NominalTypeDecl *derivedNominal,
113+
const GenericParamList *derivedParams)
114+
: baseMethodSig(baseMethodSig),
115+
baseNominal(baseNominal),
116+
derivedNominal(derivedNominal),
117+
derivedParams(derivedParams) {}
115118
};
116119

117120
namespace llvm {
@@ -122,28 +125,32 @@ template <> struct DenseMapInfo<OverrideSignatureKey> {
122125
static bool isEqual(const OverrideSignatureKey lhs,
123126
const OverrideSignatureKey rhs) {
124127
return lhs.baseMethodSig.getPointer() == rhs.baseMethodSig.getPointer() &&
125-
lhs.derivedMethodSig.getPointer() == rhs.derivedMethodSig.getPointer() &&
126-
lhs.derivedNominal == rhs.derivedNominal;
128+
lhs.baseNominal == rhs.baseNominal &&
129+
lhs.derivedNominal == rhs.derivedNominal &&
130+
lhs.derivedParams == rhs.derivedParams;
127131
}
128132

129133
static inline OverrideSignatureKey getEmptyKey() {
130134
return OverrideSignatureKey(DenseMapInfo<GenericSignature>::getEmptyKey(),
131-
DenseMapInfo<GenericSignature>::getEmptyKey(),
132-
DenseMapInfo<NominalTypeDecl *>::getEmptyKey());
135+
DenseMapInfo<NominalTypeDecl *>::getEmptyKey(),
136+
DenseMapInfo<NominalTypeDecl *>::getEmptyKey(),
137+
DenseMapInfo<GenericParamList *>::getEmptyKey());
133138
}
134139

135140
static inline OverrideSignatureKey getTombstoneKey() {
136141
return OverrideSignatureKey(
137142
DenseMapInfo<GenericSignature>::getTombstoneKey(),
138-
DenseMapInfo<GenericSignature>::getTombstoneKey(),
139-
DenseMapInfo<NominalTypeDecl *>::getTombstoneKey());
143+
DenseMapInfo<NominalTypeDecl *>::getTombstoneKey(),
144+
DenseMapInfo<NominalTypeDecl *>::getTombstoneKey(),
145+
DenseMapInfo<GenericParamList *>::getTombstoneKey());
140146
}
141147

142148
static unsigned getHashValue(const OverrideSignatureKey &Val) {
143149
return hash_combine(
144150
DenseMapInfo<GenericSignature>::getHashValue(Val.baseMethodSig),
145-
DenseMapInfo<GenericSignature>::getHashValue(Val.derivedMethodSig),
146-
DenseMapInfo<NominalTypeDecl *>::getHashValue(Val.derivedNominal));
151+
DenseMapInfo<NominalTypeDecl *>::getHashValue(Val.baseNominal),
152+
DenseMapInfo<NominalTypeDecl *>::getHashValue(Val.derivedNominal),
153+
DenseMapInfo<GenericParamList *>::getHashValue(Val.derivedParams));
147154
}
148155
};
149156
} // namespace llvm
@@ -5271,86 +5278,56 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52715278

52725279
const auto baseGenericSig =
52735280
base->getAsGenericContext()->getGenericSignature();
5274-
const auto derivedGenericSig =
5275-
derived->getAsGenericContext()->getGenericSignature();
5281+
const auto *derivedParams =
5282+
derived->getAsGenericContext()->getGenericParams();
5283+
5284+
return getOverrideGenericSignature(baseNominal, derivedNominal,
5285+
baseGenericSig, derivedParams);
5286+
}
52765287

5277-
if (base == derived)
5278-
return derivedGenericSig;
5288+
GenericSignature
5289+
ASTContext::getOverrideGenericSignature(const NominalTypeDecl *baseNominal,
5290+
const NominalTypeDecl *derivedNominal,
5291+
GenericSignature baseGenericSig,
5292+
const GenericParamList *derivedParams) {
5293+
if (baseNominal == derivedNominal)
5294+
return baseGenericSig;
5295+
5296+
const auto derivedNominalSig = derivedNominal->getGenericSignature();
52795297

5280-
if (derivedGenericSig.isNull())
5298+
if (derivedNominalSig.isNull() && derivedParams == nullptr)
52815299
return nullptr;
52825300

52835301
if (baseGenericSig.isNull())
5284-
return derivedGenericSig;
5302+
return derivedNominalSig;
52855303

52865304
auto key = OverrideSignatureKey(baseGenericSig,
5287-
derivedGenericSig,
5288-
derivedNominal);
5305+
baseNominal,
5306+
derivedNominal,
5307+
derivedParams);
52895308

52905309
if (getImpl().overrideSigCache.find(key) !=
52915310
getImpl().overrideSigCache.end()) {
52925311
return getImpl().overrideSigCache.lookup(key);
52935312
}
52945313

5295-
const auto derivedNominalSig = derivedNominal->getGenericSignature();
5296-
52975314
SmallVector<GenericTypeParamType *, 2> addedGenericParams;
5298-
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
5299-
for (auto gp : *gpList) {
5315+
if (derivedParams) {
5316+
for (auto gp : *derivedParams) {
53005317
addedGenericParams.push_back(
53015318
gp->getDeclaredInterfaceType()->castTo<GenericTypeParamType>());
53025319
}
53035320
}
53045321

53055322
SmallVector<Requirement, 2> addedRequirements;
53065323

5307-
if (isa<ProtocolDecl>(baseNominal)) {
5308-
assert(isa<ProtocolDecl>(derivedNominal));
5324+
OverrideSubsInfo info(baseNominal, derivedNominal,
5325+
baseGenericSig, derivedParams);
53095326

5310-
for (auto reqt : baseGenericSig.getRequirements()) {
5311-
addedRequirements.push_back(reqt);
5312-
}
5313-
} else {
5314-
const auto derivedSuperclass = cast<ClassDecl>(derivedNominal)
5315-
->getSuperclass();
5316-
if (derivedSuperclass.isNull())
5317-
return nullptr;
5318-
5319-
unsigned derivedDepth = 0;
5320-
unsigned baseDepth = 0;
5321-
if (derivedNominalSig)
5322-
derivedDepth = derivedNominalSig.getGenericParams().back()->getDepth() + 1;
5323-
if (const auto baseNominalSig = baseNominal->getGenericSignature())
5324-
baseDepth = baseNominalSig.getGenericParams().back()->getDepth() + 1;
5325-
5326-
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
5327-
derivedNominal->getModuleContext(), baseNominal);
5328-
5329-
auto substFn = [&](SubstitutableType *type) -> Type {
5330-
auto *gp = cast<GenericTypeParamType>(type);
5331-
5332-
if (gp->getDepth() < baseDepth) {
5333-
return Type(gp).subst(subMap);
5334-
}
5335-
5336-
return CanGenericTypeParamType::get(
5337-
gp->isTypeSequence(), gp->getDepth() - baseDepth + derivedDepth,
5338-
gp->getIndex(), *this);
5339-
};
5340-
5341-
auto lookupConformanceFn =
5342-
[&](CanType depTy, Type substTy,
5343-
ProtocolDecl *proto) -> ProtocolConformanceRef {
5344-
if (auto conf = subMap.lookupConformance(depTy, proto))
5345-
return conf;
5346-
5347-
return ProtocolConformanceRef(proto);
5348-
};
5349-
5350-
for (auto reqt : baseGenericSig.getRequirements()) {
5351-
if (auto substReqt = reqt.subst(substFn, lookupConformanceFn)) {
5352-
addedRequirements.push_back(*substReqt);
5353-
}
5327+
for (auto reqt : baseGenericSig.getRequirements()) {
5328+
if (auto substReqt = reqt.subst(QueryOverrideSubs(info),
5329+
LookUpConformanceInOverrideSubs(info))) {
5330+
addedRequirements.push_back(*substReqt);
53545331
}
53555332
}
53565333

lib/AST/ConcreteDeclRef.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ ConcreteDeclRef ConcreteDeclRef::getOverriddenDecl() const {
3636

3737
SubstitutionMap subs;
3838
if (baseSig) {
39-
Optional<SubstitutionMap> derivedSubMap;
39+
subs = SubstitutionMap::getOverrideSubstitutions(baseDecl, derivedDecl);
4040
if (derivedSig)
41-
derivedSubMap = getSubstitutions();
42-
subs = SubstitutionMap::getOverrideSubstitutions(baseDecl, derivedDecl,
43-
derivedSubMap);
41+
subs = subs.subst(getSubstitutions());
4442
}
4543
return ConcreteDeclRef(baseDecl, subs);
4644
}

0 commit comments

Comments
 (0)