Skip to content

Commit c4a23c4

Browse files
authored
Merge pull request #21245 from slavapestov/invalid-extensions-5.0
Fix two crash-on-invalid involving extensions [5.0]
2 parents f6d3325 + ecba3a5 commit c4a23c4

File tree

8 files changed

+213
-208
lines changed

8 files changed

+213
-208
lines changed

include/swift/AST/Decl.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,13 +1417,10 @@ class GenericParamList final :
14171417
Requirements.back().getSourceRange().End);
14181418
}
14191419

1420-
/// Retrieve the depth of this generic parameter list.
1421-
unsigned getDepth() const {
1422-
unsigned depth = 0;
1423-
for (auto gp = getOuterParameters(); gp; gp = gp->getOuterParameters())
1424-
++depth;
1425-
return depth;
1426-
}
1420+
unsigned getDepth() const;
1421+
1422+
/// Configure the depth of the generic parameters in this list.
1423+
void configureGenericParamDepth();
14271424

14281425
/// Create a copy of the generic parameter list and all of its generic
14291426
/// parameter declarations. The copied generic parameters are re-parented
@@ -1759,6 +1756,8 @@ class ExtensionDecl final : public GenericContext, public Decl,
17591756
return getValidationState() > ValidationState::CheckingWithValidSignature;
17601757
}
17611758

1759+
void createGenericParamsIfMissing(NominalTypeDecl *nominal);
1760+
17621761
bool hasDefaultAccessLevel() const {
17631762
return Bits.ExtensionDecl.DefaultAndMaxAccessLevel != 0;
17641763
}

lib/AST/Decl.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,21 @@ void GenericParamList::addTrailingWhereClause(
689689
Requirements = newRequirements;
690690
}
691691

692+
unsigned GenericParamList::getDepth() const {
693+
unsigned depth = 0;
694+
for (auto gpList = getOuterParameters();
695+
gpList != nullptr;
696+
gpList = gpList->getOuterParameters())
697+
++depth;
698+
return depth;
699+
}
700+
701+
void GenericParamList::configureGenericParamDepth() {
702+
unsigned depth = getDepth();
703+
for (auto param : *this)
704+
param->setDepth(depth);
705+
}
706+
692707
TrailingWhereClause::TrailingWhereClause(
693708
SourceLoc whereLoc,
694709
ArrayRef<RequirementRepr> requirements)
@@ -1043,6 +1058,113 @@ AccessLevel ExtensionDecl::getMaxAccessLevel() const {
10431058
{AccessLevel::Private, AccessLevel::Private}).second;
10441059
}
10451060

1061+
/// Clone the given generic parameters in the given list. We don't need any
1062+
/// of the requirements, because they will be inferred.
1063+
static GenericParamList *cloneGenericParams(ASTContext &ctx,
1064+
DeclContext *dc,
1065+
GenericParamList *fromParams) {
1066+
// Clone generic parameters.
1067+
SmallVector<GenericTypeParamDecl *, 2> toGenericParams;
1068+
for (auto fromGP : *fromParams) {
1069+
// Create the new generic parameter.
1070+
auto toGP = new (ctx) GenericTypeParamDecl(dc, fromGP->getName(),
1071+
SourceLoc(),
1072+
fromGP->getDepth(),
1073+
fromGP->getIndex());
1074+
toGP->setImplicit(true);
1075+
1076+
// Record new generic parameter.
1077+
toGenericParams.push_back(toGP);
1078+
}
1079+
1080+
auto toParams = GenericParamList::create(ctx, SourceLoc(), toGenericParams,
1081+
SourceLoc());
1082+
1083+
auto outerParams = fromParams->getOuterParameters();
1084+
if (outerParams != nullptr)
1085+
outerParams = cloneGenericParams(ctx, dc, outerParams);
1086+
toParams->setOuterParameters(outerParams);
1087+
1088+
return toParams;
1089+
}
1090+
1091+
/// Ensure that the outer generic parameters of the given generic
1092+
/// context have been configured.
1093+
static void configureOuterGenericParams(const GenericContext *dc) {
1094+
auto genericParams = dc->getGenericParams();
1095+
1096+
// If we already configured the outer parameters, we're done.
1097+
if (genericParams && genericParams->getOuterParameters())
1098+
return;
1099+
1100+
DeclContext *outerDC = dc->getParent();
1101+
while (!outerDC->isModuleScopeContext()) {
1102+
if (auto outerDecl = outerDC->getAsDecl()) {
1103+
if (auto outerGenericDC = outerDecl->getAsGenericContext()) {
1104+
if (genericParams)
1105+
genericParams->setOuterParameters(outerGenericDC->getGenericParams());
1106+
1107+
configureOuterGenericParams(outerGenericDC);
1108+
return;
1109+
}
1110+
}
1111+
1112+
outerDC = outerDC->getParent();
1113+
}
1114+
}
1115+
1116+
void ExtensionDecl::createGenericParamsIfMissing(NominalTypeDecl *nominal) {
1117+
if (getGenericParams())
1118+
return;
1119+
1120+
// Hack to force generic parameter lists of protocols to be created if the
1121+
// nominal is an (invalid) nested type of a protocol.
1122+
DeclContext *outerDC = nominal;
1123+
while (!outerDC->isModuleScopeContext()) {
1124+
if (auto *proto = dyn_cast<ProtocolDecl>(outerDC))
1125+
proto->createGenericParamsIfMissing();
1126+
1127+
outerDC = outerDC->getParent();
1128+
}
1129+
1130+
configureOuterGenericParams(nominal);
1131+
1132+
if (auto proto = dyn_cast<ProtocolDecl>(nominal)) {
1133+
// For a protocol extension, build the generic parameter list directly
1134+
// since we want it to have an inheritance clause.
1135+
setGenericParams(proto->createGenericParams(this));
1136+
} else if (auto genericParams = nominal->getGenericParamsOfContext()) {
1137+
// Clone the generic parameter list of a generic type.
1138+
setGenericParams(
1139+
cloneGenericParams(getASTContext(), this, genericParams));
1140+
}
1141+
1142+
// Set the depth of every generic parameter.
1143+
auto *genericParams = getGenericParams();
1144+
for (auto *outerParams = genericParams;
1145+
outerParams != nullptr;
1146+
outerParams = outerParams->getOuterParameters())
1147+
outerParams->configureGenericParamDepth();
1148+
1149+
// If we have a trailing where clause, deal with it now.
1150+
// For now, trailing where clauses are only permitted on protocol extensions.
1151+
if (auto trailingWhereClause = getTrailingWhereClause()) {
1152+
if (genericParams) {
1153+
// Merge the trailing where clause into the generic parameter list.
1154+
// FIXME: Long-term, we'd like clients to deal with the trailing where
1155+
// clause explicitly, but for now it's far more direct to represent
1156+
// the trailing where clause as part of the requirements.
1157+
genericParams->addTrailingWhereClause(
1158+
getASTContext(),
1159+
trailingWhereClause->getWhereLoc(),
1160+
trailingWhereClause->getRequirements());
1161+
}
1162+
1163+
// If there's no generic parameter list, the where clause is diagnosed
1164+
// in typeCheckDecl().
1165+
}
1166+
}
1167+
10461168
PatternBindingDecl::PatternBindingDecl(SourceLoc StaticLoc,
10471169
StaticSpellingKind StaticSpelling,
10481170
SourceLoc VarLoc,

0 commit comments

Comments
 (0)