Skip to content

Commit c498ad0

Browse files
committed
IDE: Ensure syntax coloring for contextual where clauses
1 parent 797805a commit c498ad0

File tree

2 files changed

+59
-50
lines changed

2 files changed

+59
-50
lines changed

lib/AST/ASTWalker.cpp

Lines changed: 40 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,31 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
126126
// Decls
127127
//===--------------------------------------------------------------------===//
128128

129+
bool visitGenericParamListIfNeeded(GenericContext *GC) {
130+
// Must check this first in case extensions have not been bound yet
131+
if (Walker.shouldWalkIntoGenericParams()) {
132+
if (auto *params = GC->getGenericParams()) {
133+
visitGenericParamList(params);
134+
}
135+
return true;
136+
}
137+
return false;
138+
}
139+
140+
bool visitTrailingRequirements(GenericContext *GC) {
141+
if (const auto Where = GC->getTrailingWhereClause()) {
142+
for (auto &Req: Where->getRequirements())
143+
if (doIt(Req))
144+
return true;
145+
} else if (!isa<ExtensionDecl>(GC)) {
146+
if (const auto GP = GC->getGenericParams())
147+
for (auto Req: GP->getTrailingRequirements())
148+
if (doIt(Req))
149+
return true;
150+
}
151+
return false;
152+
}
153+
129154
bool visitImportDecl(ImportDecl *ID) {
130155
return false;
131156
}
@@ -138,12 +163,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
138163
if (doIt(Inherit))
139164
return true;
140165
}
141-
if (auto *Where = ED->getTrailingWhereClause()) {
142-
for(auto &Req: Where->getRequirements()) {
143-
if (doIt(Req))
144-
return true;
145-
}
146-
}
166+
if (visitTrailingRequirements(ED))
167+
return true;
168+
147169
for (Decl *M : ED->getMembers()) {
148170
if (doIt(M))
149171
return true;
@@ -223,15 +245,13 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
223245
}
224246

225247
bool visitTypeAliasDecl(TypeAliasDecl *TAD) {
226-
if (Walker.shouldWalkIntoGenericParams() && TAD->getGenericParams()) {
227-
if (visitGenericParamList(TAD->getGenericParams()))
228-
return true;
229-
}
248+
bool WalkGenerics = visitGenericParamListIfNeeded(TAD);
230249

231250
if (auto typerepr = TAD->getUnderlyingTypeRepr())
232251
if (doIt(typerepr))
233252
return true;
234-
return false;
253+
254+
return WalkGenerics && visitTrailingRequirements(TAD);
235255
}
236256

237257
bool visitOpaqueTypeDecl(OpaqueTypeDecl *OTD) {
@@ -269,20 +289,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
269289
}
270290

271291
// Visit requirements
272-
if (WalkGenerics) {
273-
ArrayRef<swift::RequirementRepr> Reqs = None;
274-
if (auto *Protocol = dyn_cast<ProtocolDecl>(NTD)) {
275-
if (auto *WhereClause = Protocol->getTrailingWhereClause())
276-
Reqs = WhereClause->getRequirements();
277-
} else {
278-
Reqs = NTD->getGenericParams()->getTrailingRequirements();
279-
}
280-
for (auto Req: Reqs) {
281-
if (doIt(Req))
282-
return true;
283-
}
284-
}
285-
292+
if (WalkGenerics && visitTrailingRequirements(NTD))
293+
return true;
294+
286295
for (Decl *Member : NTD->getMembers()) {
287296
if (doIt(Member))
288297
return true;
@@ -325,13 +334,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
325334
if (doIt(SD->getElementTypeLoc()))
326335
return true;
327336

328-
if (WalkGenerics) {
329-
// Visit generic requirements
330-
for (auto Req : SD->getGenericParams()->getTrailingRequirements()) {
331-
if (doIt(Req))
332-
return true;
333-
}
334-
}
337+
// Visit trailing requirements
338+
if (WalkGenerics && visitTrailingRequirements(SD))
339+
return true;
335340

336341
if (!Walker.shouldWalkAccessorsTheOldWay()) {
337342
for (auto *AD : SD->getAllAccessors())
@@ -364,13 +369,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
364369
if (doIt(FD->getBodyResultTypeLoc()))
365370
return true;
366371

367-
if (WalkGenerics) {
368-
// Visit trailing requirments
369-
for (auto Req : AFD->getGenericParams()->getTrailingRequirements()) {
370-
if (doIt(Req))
371-
return true;
372-
}
373-
}
372+
// Visit trailing requirements
373+
if (WalkGenerics && visitTrailingRequirements(AFD))
374+
return true;
374375

375376
if (AFD->getBody(/*canSynthesize=*/false)) {
376377
AbstractFunctionDecl::BodyKind PreservedKind = AFD->getBodyKind();
@@ -1323,17 +1324,6 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13231324
}
13241325
return false;
13251326
}
1326-
1327-
private:
1328-
bool visitGenericParamListIfNeeded(GenericContext *gc) {
1329-
if (Walker.shouldWalkIntoGenericParams()) {
1330-
if (auto *params = gc->getGenericParams()) {
1331-
visitGenericParamList(params);
1332-
return true;
1333-
}
1334-
}
1335-
return false;
1336-
}
13371327
};
13381328

13391329
} // end anonymous namespace

test/IDE/coloring.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,22 @@ enum E {
522522
// CHECK: <kw>var</kw> <kw>_</kw> = <int>10</int>
523523
@available(iOS 99, *)
524524
var _ = 10
525+
526+
// CHECK: <type>Array</type><<type>T</type>> <kw>where</kw> <type>T</type>: <type>Equatable</type>
527+
typealias GenericAlias<T> = Array<T> where T: Equatable
528+
529+
// Where clauses on contextually generic declarations
530+
//
531+
struct FreeWhere<T> {
532+
// CHECK: <kw>func</kw> foo() <kw>where</kw> <type>T</type> == <type>Int</type>
533+
func foo() where T == Int {}
534+
535+
// CHECK: <kw>subscript</kw>() -> <type>Int</type> <kw>where</kw> <type>T</type>: <type>Sequence</type>
536+
subscript() -> Int where T: Sequence {}
537+
538+
// CHECK: <kw>enum</kw> Enum <kw>where</kw> <type>T</type> == <type>Int</type>
539+
enum Enum where T == Int {}
540+
541+
// CHECK: <kw>typealias</kw> Alias = <type>Int</type> <kw>where</kw> <type>T</type> == <type>Int</type>
542+
typealias Alias = Int where T == Int
543+
}

0 commit comments

Comments
 (0)