Skip to content

Commit 4d7542c

Browse files
authored
Merge pull request swiftlang#21226 from slavapestov/invalid-extensions
Fix two crash-on-invalid involving extensions
2 parents 191a71e + 4c0b391 commit 4d7542c

File tree

8 files changed

+213
-208
lines changed

8 files changed

+213
-208
lines changed

include/swift/AST/Decl.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,13 +1422,10 @@ class GenericParamList final :
14221422
Requirements.back().getSourceRange().End);
14231423
}
14241424

1425-
/// Retrieve the depth of this generic parameter list.
1426-
unsigned getDepth() const {
1427-
unsigned depth = 0;
1428-
for (auto gp = getOuterParameters(); gp; gp = gp->getOuterParameters())
1429-
++depth;
1430-
return depth;
1431-
}
1425+
unsigned getDepth() const;
1426+
1427+
/// Configure the depth of the generic parameters in this list.
1428+
void configureGenericParamDepth();
14321429

14331430
/// Create a copy of the generic parameter list and all of its generic
14341431
/// parameter declarations. The copied generic parameters are re-parented
@@ -1764,6 +1761,8 @@ class ExtensionDecl final : public GenericContext, public Decl,
17641761
return getValidationState() > ValidationState::CheckingWithValidSignature;
17651762
}
17661763

1764+
void createGenericParamsIfMissing(NominalTypeDecl *nominal);
1765+
17671766
bool hasDefaultAccessLevel() const {
17681767
return Bits.ExtensionDecl.DefaultAndMaxAccessLevel != 0;
17691768
}

lib/AST/Decl.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,21 @@ void GenericParamList::addTrailingWhereClause(
687687
Requirements = newRequirements;
688688
}
689689

690+
unsigned GenericParamList::getDepth() const {
691+
unsigned depth = 0;
692+
for (auto gpList = getOuterParameters();
693+
gpList != nullptr;
694+
gpList = gpList->getOuterParameters())
695+
++depth;
696+
return depth;
697+
}
698+
699+
void GenericParamList::configureGenericParamDepth() {
700+
unsigned depth = getDepth();
701+
for (auto param : *this)
702+
param->setDepth(depth);
703+
}
704+
690705
TrailingWhereClause::TrailingWhereClause(
691706
SourceLoc whereLoc,
692707
ArrayRef<RequirementRepr> requirements)
@@ -1041,6 +1056,113 @@ AccessLevel ExtensionDecl::getMaxAccessLevel() const {
10411056
{AccessLevel::Private, AccessLevel::Private}).second;
10421057
}
10431058

1059+
/// Clone the given generic parameters in the given list. We don't need any
1060+
/// of the requirements, because they will be inferred.
1061+
static GenericParamList *cloneGenericParams(ASTContext &ctx,
1062+
DeclContext *dc,
1063+
GenericParamList *fromParams) {
1064+
// Clone generic parameters.
1065+
SmallVector<GenericTypeParamDecl *, 2> toGenericParams;
1066+
for (auto fromGP : *fromParams) {
1067+
// Create the new generic parameter.
1068+
auto toGP = new (ctx) GenericTypeParamDecl(dc, fromGP->getName(),
1069+
SourceLoc(),
1070+
fromGP->getDepth(),
1071+
fromGP->getIndex());
1072+
toGP->setImplicit(true);
1073+
1074+
// Record new generic parameter.
1075+
toGenericParams.push_back(toGP);
1076+
}
1077+
1078+
auto toParams = GenericParamList::create(ctx, SourceLoc(), toGenericParams,
1079+
SourceLoc());
1080+
1081+
auto outerParams = fromParams->getOuterParameters();
1082+
if (outerParams != nullptr)
1083+
outerParams = cloneGenericParams(ctx, dc, outerParams);
1084+
toParams->setOuterParameters(outerParams);
1085+
1086+
return toParams;
1087+
}
1088+
1089+
/// Ensure that the outer generic parameters of the given generic
1090+
/// context have been configured.
1091+
static void configureOuterGenericParams(const GenericContext *dc) {
1092+
auto genericParams = dc->getGenericParams();
1093+
1094+
// If we already configured the outer parameters, we're done.
1095+
if (genericParams && genericParams->getOuterParameters())
1096+
return;
1097+
1098+
DeclContext *outerDC = dc->getParent();
1099+
while (!outerDC->isModuleScopeContext()) {
1100+
if (auto outerDecl = outerDC->getAsDecl()) {
1101+
if (auto outerGenericDC = outerDecl->getAsGenericContext()) {
1102+
if (genericParams)
1103+
genericParams->setOuterParameters(outerGenericDC->getGenericParams());
1104+
1105+
configureOuterGenericParams(outerGenericDC);
1106+
return;
1107+
}
1108+
}
1109+
1110+
outerDC = outerDC->getParent();
1111+
}
1112+
}
1113+
1114+
void ExtensionDecl::createGenericParamsIfMissing(NominalTypeDecl *nominal) {
1115+
if (getGenericParams())
1116+
return;
1117+
1118+
// Hack to force generic parameter lists of protocols to be created if the
1119+
// nominal is an (invalid) nested type of a protocol.
1120+
DeclContext *outerDC = nominal;
1121+
while (!outerDC->isModuleScopeContext()) {
1122+
if (auto *proto = dyn_cast<ProtocolDecl>(outerDC))
1123+
proto->createGenericParamsIfMissing();
1124+
1125+
outerDC = outerDC->getParent();
1126+
}
1127+
1128+
configureOuterGenericParams(nominal);
1129+
1130+
if (auto proto = dyn_cast<ProtocolDecl>(nominal)) {
1131+
// For a protocol extension, build the generic parameter list directly
1132+
// since we want it to have an inheritance clause.
1133+
setGenericParams(proto->createGenericParams(this));
1134+
} else if (auto genericParams = nominal->getGenericParamsOfContext()) {
1135+
// Clone the generic parameter list of a generic type.
1136+
setGenericParams(
1137+
cloneGenericParams(getASTContext(), this, genericParams));
1138+
}
1139+
1140+
// Set the depth of every generic parameter.
1141+
auto *genericParams = getGenericParams();
1142+
for (auto *outerParams = genericParams;
1143+
outerParams != nullptr;
1144+
outerParams = outerParams->getOuterParameters())
1145+
outerParams->configureGenericParamDepth();
1146+
1147+
// If we have a trailing where clause, deal with it now.
1148+
// For now, trailing where clauses are only permitted on protocol extensions.
1149+
if (auto trailingWhereClause = getTrailingWhereClause()) {
1150+
if (genericParams) {
1151+
// Merge the trailing where clause into the generic parameter list.
1152+
// FIXME: Long-term, we'd like clients to deal with the trailing where
1153+
// clause explicitly, but for now it's far more direct to represent
1154+
// the trailing where clause as part of the requirements.
1155+
genericParams->addTrailingWhereClause(
1156+
getASTContext(),
1157+
trailingWhereClause->getWhereLoc(),
1158+
trailingWhereClause->getRequirements());
1159+
}
1160+
1161+
// If there's no generic parameter list, the where clause is diagnosed
1162+
// in typeCheckDecl().
1163+
}
1164+
}
1165+
10441166
PatternBindingDecl::PatternBindingDecl(SourceLoc StaticLoc,
10451167
StaticSpellingKind StaticSpelling,
10461168
SourceLoc VarLoc,

0 commit comments

Comments
 (0)