@@ -103,15 +103,18 @@ using AssociativityCacheType =
103
103
104
104
struct OverrideSignatureKey {
105
105
GenericSignature baseMethodSig;
106
- GenericSignature derivedMethodSig;
107
- NominalTypeDecl *derivedNominal;
108
-
109
- OverrideSignatureKey (GenericSignature baseMethodSignature,
110
- GenericSignature derivedMethodSignature,
111
- NominalTypeDecl *derivedNominal)
112
- : baseMethodSig(baseMethodSignature),
113
- derivedMethodSig (derivedMethodSignature),
114
- derivedNominal(derivedNominal) {}
106
+ const NominalTypeDecl *baseNominal;
107
+ const NominalTypeDecl *derivedNominal;
108
+ const GenericParamList *derivedParams;
109
+
110
+ OverrideSignatureKey (GenericSignature baseMethodSig,
111
+ const NominalTypeDecl *baseNominal,
112
+ const NominalTypeDecl *derivedNominal,
113
+ const GenericParamList *derivedParams)
114
+ : baseMethodSig(baseMethodSig),
115
+ baseNominal (baseNominal),
116
+ derivedNominal(derivedNominal),
117
+ derivedParams(derivedParams) {}
115
118
};
116
119
117
120
namespace llvm {
@@ -122,28 +125,32 @@ template <> struct DenseMapInfo<OverrideSignatureKey> {
122
125
static bool isEqual (const OverrideSignatureKey lhs,
123
126
const OverrideSignatureKey rhs) {
124
127
return lhs.baseMethodSig .getPointer () == rhs.baseMethodSig .getPointer () &&
125
- lhs.derivedMethodSig .getPointer () == rhs.derivedMethodSig .getPointer () &&
126
- lhs.derivedNominal == rhs.derivedNominal ;
128
+ lhs.baseNominal == rhs.baseNominal &&
129
+ lhs.derivedNominal == rhs.derivedNominal &&
130
+ lhs.derivedParams == rhs.derivedParams ;
127
131
}
128
132
129
133
static inline OverrideSignatureKey getEmptyKey () {
130
134
return OverrideSignatureKey (DenseMapInfo<GenericSignature>::getEmptyKey (),
131
- DenseMapInfo<GenericSignature>::getEmptyKey (),
132
- DenseMapInfo<NominalTypeDecl *>::getEmptyKey ());
135
+ DenseMapInfo<NominalTypeDecl *>::getEmptyKey (),
136
+ DenseMapInfo<NominalTypeDecl *>::getEmptyKey (),
137
+ DenseMapInfo<GenericParamList *>::getEmptyKey ());
133
138
}
134
139
135
140
static inline OverrideSignatureKey getTombstoneKey () {
136
141
return OverrideSignatureKey (
137
142
DenseMapInfo<GenericSignature>::getTombstoneKey (),
138
- DenseMapInfo<GenericSignature>::getTombstoneKey (),
139
- DenseMapInfo<NominalTypeDecl *>::getTombstoneKey ());
143
+ DenseMapInfo<NominalTypeDecl *>::getTombstoneKey (),
144
+ DenseMapInfo<NominalTypeDecl *>::getTombstoneKey (),
145
+ DenseMapInfo<GenericParamList *>::getTombstoneKey ());
140
146
}
141
147
142
148
static unsigned getHashValue (const OverrideSignatureKey &Val) {
143
149
return hash_combine (
144
150
DenseMapInfo<GenericSignature>::getHashValue (Val.baseMethodSig ),
145
- DenseMapInfo<GenericSignature>::getHashValue (Val.derivedMethodSig ),
146
- DenseMapInfo<NominalTypeDecl *>::getHashValue (Val.derivedNominal ));
151
+ DenseMapInfo<NominalTypeDecl *>::getHashValue (Val.baseNominal ),
152
+ DenseMapInfo<NominalTypeDecl *>::getHashValue (Val.derivedNominal ),
153
+ DenseMapInfo<GenericParamList *>::getHashValue (Val.derivedParams ));
147
154
}
148
155
};
149
156
} // namespace llvm
@@ -5271,86 +5278,56 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
5271
5278
5272
5279
const auto baseGenericSig =
5273
5280
base->getAsGenericContext ()->getGenericSignature ();
5274
- const auto derivedGenericSig =
5275
- derived->getAsGenericContext ()->getGenericSignature ();
5281
+ const auto *derivedParams =
5282
+ derived->getAsGenericContext ()->getGenericParams ();
5283
+
5284
+ return getOverrideGenericSignature (baseNominal, derivedNominal,
5285
+ baseGenericSig, derivedParams);
5286
+ }
5276
5287
5277
- if (base == derived)
5278
- return derivedGenericSig;
5288
+ GenericSignature
5289
+ ASTContext::getOverrideGenericSignature (const NominalTypeDecl *baseNominal,
5290
+ const NominalTypeDecl *derivedNominal,
5291
+ GenericSignature baseGenericSig,
5292
+ const GenericParamList *derivedParams) {
5293
+ if (baseNominal == derivedNominal)
5294
+ return baseGenericSig;
5295
+
5296
+ const auto derivedNominalSig = derivedNominal->getGenericSignature ();
5279
5297
5280
- if (derivedGenericSig .isNull ())
5298
+ if (derivedNominalSig .isNull () && derivedParams == nullptr )
5281
5299
return nullptr ;
5282
5300
5283
5301
if (baseGenericSig.isNull ())
5284
- return derivedGenericSig ;
5302
+ return derivedNominalSig ;
5285
5303
5286
5304
auto key = OverrideSignatureKey (baseGenericSig,
5287
- derivedGenericSig,
5288
- derivedNominal);
5305
+ baseNominal,
5306
+ derivedNominal,
5307
+ derivedParams);
5289
5308
5290
5309
if (getImpl ().overrideSigCache .find (key) !=
5291
5310
getImpl ().overrideSigCache .end ()) {
5292
5311
return getImpl ().overrideSigCache .lookup (key);
5293
5312
}
5294
5313
5295
- const auto derivedNominalSig = derivedNominal->getGenericSignature ();
5296
-
5297
5314
SmallVector<GenericTypeParamType *, 2 > addedGenericParams;
5298
- if (const auto *gpList = derived-> getAsGenericContext ()-> getGenericParams () ) {
5299
- for (auto gp : *gpList ) {
5315
+ if (derivedParams ) {
5316
+ for (auto gp : *derivedParams ) {
5300
5317
addedGenericParams.push_back (
5301
5318
gp->getDeclaredInterfaceType ()->castTo <GenericTypeParamType>());
5302
5319
}
5303
5320
}
5304
5321
5305
5322
SmallVector<Requirement, 2 > addedRequirements;
5306
5323
5307
- if (isa<ProtocolDecl>( baseNominal)) {
5308
- assert (isa<ProtocolDecl>(derivedNominal) );
5324
+ OverrideSubsInfo info ( baseNominal, derivedNominal,
5325
+ baseGenericSig, derivedParams );
5309
5326
5310
- for (auto reqt : baseGenericSig.getRequirements ()) {
5311
- addedRequirements.push_back (reqt);
5312
- }
5313
- } else {
5314
- const auto derivedSuperclass = cast<ClassDecl>(derivedNominal)
5315
- ->getSuperclass ();
5316
- if (derivedSuperclass.isNull ())
5317
- return nullptr ;
5318
-
5319
- unsigned derivedDepth = 0 ;
5320
- unsigned baseDepth = 0 ;
5321
- if (derivedNominalSig)
5322
- derivedDepth = derivedNominalSig.getGenericParams ().back ()->getDepth () + 1 ;
5323
- if (const auto baseNominalSig = baseNominal->getGenericSignature ())
5324
- baseDepth = baseNominalSig.getGenericParams ().back ()->getDepth () + 1 ;
5325
-
5326
- const auto subMap = derivedSuperclass->getContextSubstitutionMap (
5327
- derivedNominal->getModuleContext (), baseNominal);
5328
-
5329
- auto substFn = [&](SubstitutableType *type) -> Type {
5330
- auto *gp = cast<GenericTypeParamType>(type);
5331
-
5332
- if (gp->getDepth () < baseDepth) {
5333
- return Type (gp).subst (subMap);
5334
- }
5335
-
5336
- return CanGenericTypeParamType::get (
5337
- gp->isTypeSequence (), gp->getDepth () - baseDepth + derivedDepth,
5338
- gp->getIndex (), *this );
5339
- };
5340
-
5341
- auto lookupConformanceFn =
5342
- [&](CanType depTy, Type substTy,
5343
- ProtocolDecl *proto) -> ProtocolConformanceRef {
5344
- if (auto conf = subMap.lookupConformance (depTy, proto))
5345
- return conf;
5346
-
5347
- return ProtocolConformanceRef (proto);
5348
- };
5349
-
5350
- for (auto reqt : baseGenericSig.getRequirements ()) {
5351
- if (auto substReqt = reqt.subst (substFn, lookupConformanceFn)) {
5352
- addedRequirements.push_back (*substReqt);
5353
- }
5327
+ for (auto reqt : baseGenericSig.getRequirements ()) {
5328
+ if (auto substReqt = reqt.subst (QueryOverrideSubs (info),
5329
+ LookUpConformanceInOverrideSubs (info))) {
5330
+ addedRequirements.push_back (*substReqt);
5354
5331
}
5355
5332
}
5356
5333
0 commit comments