@@ -331,25 +331,36 @@ static void bindExtensionToNominal(ExtensionDecl *ext,
331
331
if (ext->alreadyBoundToNominal ())
332
332
return ;
333
333
334
+ // Hack to force generic parameter lists of protocols to be created if the
335
+ // nominal is an (invalid) nested type of a protocol.
336
+ DeclContext *outerDC = nominal;
337
+ while (!outerDC->isModuleScopeContext ()) {
338
+ if (auto *proto = dyn_cast<ProtocolDecl>(outerDC))
339
+ proto->createGenericParamsIfMissing ();
340
+
341
+ outerDC = outerDC->getParent ();
342
+ }
343
+
344
+ configureOuterGenericParams (nominal);
345
+
334
346
if (auto proto = dyn_cast<ProtocolDecl>(nominal)) {
335
- // For a protocol extension, build the generic parameter list.
336
- auto genericParams = proto->createGenericParams (ext);
337
- prepareGenericParamList (genericParams);
338
- ext->setGenericParams (genericParams);
347
+ // For a protocol extension, build the generic parameter list directly
348
+ // since we want it to have an inheritance clause.
349
+ ext->setGenericParams (proto->createGenericParams (ext));
339
350
} else if (auto genericParams = nominal->getGenericParamsOfContext ()) {
340
- // Make sure the generic parameters are set up.
341
- configureOuterGenericParams (nominal);
342
-
343
351
// Clone the generic parameter list of a generic type.
344
- prepareGenericParamList (genericParams);
345
352
ext->setGenericParams (
346
353
cloneGenericParams (ext->getASTContext (), ext, genericParams));
347
354
}
348
355
356
+ auto *genericParams = ext->getGenericParams ();
357
+ if (genericParams)
358
+ prepareGenericParamList (genericParams);
359
+
349
360
// If we have a trailing where clause, deal with it now.
350
361
// For now, trailing where clauses are only permitted on protocol extensions.
351
362
if (auto trailingWhereClause = ext->getTrailingWhereClause ()) {
352
- if (!(nominal-> getGenericParamsOfContext () || isa<ProtocolDecl>(nominal)) ) {
363
+ if (!genericParams ) {
353
364
// Only generic and protocol types are permitted to have
354
365
// trailing where clauses.
355
366
ext->diagnose (diag::extension_nongeneric_trailing_where,
@@ -361,7 +372,7 @@ static void bindExtensionToNominal(ExtensionDecl *ext,
361
372
// FIXME: Long-term, we'd like clients to deal with the trailing where
362
373
// clause explicitly, but for now it's far more direct to represent
363
374
// the trailing where clause as part of the requirements.
364
- ext-> getGenericParams () ->addTrailingWhereClause (
375
+ genericParams ->addTrailingWhereClause (
365
376
ext->getASTContext (),
366
377
trailingWhereClause->getWhereLoc (),
367
378
trailingWhereClause->getRequirements ());
0 commit comments