@@ -294,21 +294,26 @@ class InheritedProtocolCollector {
294
294
const NominalTypeDecl *nominal;
295
295
const IterableDeclContext *memberContext;
296
296
297
+ auto shouldInclude = [](const ExtensionDecl *extension) {
298
+ if (extension->isConstrainedExtension ()) {
299
+ // Conditional conformances never apply to inherited protocols, nor
300
+ // can they provide unconditional conformances that might be used in
301
+ // other extensions.
302
+ return false ;
303
+ }
304
+ return true ;
305
+ };
297
306
if ((nominal = dyn_cast<NominalTypeDecl>(D))) {
298
307
directlyInherited = nominal->getInherited ();
299
308
memberContext = nominal;
300
309
301
310
} else if (auto *extension = dyn_cast<ExtensionDecl>(D)) {
302
- if (extension->isConstrainedExtension ()) {
303
- // Conditional conformances never apply to inherited protocols, nor
304
- // can they provide unconditional conformances that might be used in
305
- // other extensions.
311
+ if (!shouldInclude (extension)) {
306
312
return ;
307
313
}
308
314
nominal = extension->getExtendedNominal ();
309
315
directlyInherited = extension->getInherited ();
310
316
memberContext = extension;
311
-
312
317
} else {
313
318
return ;
314
319
}
@@ -317,6 +322,18 @@ class InheritedProtocolCollector {
317
322
return ;
318
323
319
324
map[nominal].recordProtocols (directlyInherited, D);
325
+ // Collect protocols inherited from super classes
326
+ if (auto *CD = dyn_cast<ClassDecl>(D)) {
327
+ for (auto *SD = CD->getSuperclassDecl (); SD;
328
+ SD = SD->getSuperclassDecl ()) {
329
+ map[nominal].recordProtocols (SD->getInherited (), SD);
330
+ for (auto *Ext: SD->getExtensions ()) {
331
+ if (shouldInclude (Ext)) {
332
+ map[nominal].recordProtocols (Ext->getInherited (), Ext);
333
+ }
334
+ }
335
+ }
336
+ }
320
337
321
338
// Recurse to find any nested types.
322
339
for (const Decl *member : memberContext->getMembers ())
0 commit comments