@@ -126,6 +126,31 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
126
126
// Decls
127
127
// ===--------------------------------------------------------------------===//
128
128
129
+ bool visitGenericParamListIfNeeded (GenericContext *GC) {
130
+ // Must check this first in case extensions have not been bound yet
131
+ if (Walker.shouldWalkIntoGenericParams ()) {
132
+ if (auto *params = GC->getGenericParams ()) {
133
+ visitGenericParamList (params);
134
+ }
135
+ return true ;
136
+ }
137
+ return false ;
138
+ }
139
+
140
+ bool visitTrailingRequirements (GenericContext *GC) {
141
+ if (const auto Where = GC->getTrailingWhereClause ()) {
142
+ for (auto &Req: Where->getRequirements ())
143
+ if (doIt (Req))
144
+ return true ;
145
+ } else if (!isa<ExtensionDecl>(GC)) {
146
+ if (const auto GP = GC->getGenericParams ())
147
+ for (auto Req: GP->getTrailingRequirements ())
148
+ if (doIt (Req))
149
+ return true ;
150
+ }
151
+ return false ;
152
+ }
153
+
129
154
bool visitImportDecl (ImportDecl *ID) {
130
155
return false ;
131
156
}
@@ -138,12 +163,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
138
163
if (doIt (Inherit))
139
164
return true ;
140
165
}
141
- if (auto *Where = ED->getTrailingWhereClause ()) {
142
- for (auto &Req: Where->getRequirements ()) {
143
- if (doIt (Req))
144
- return true ;
145
- }
146
- }
166
+ if (visitTrailingRequirements (ED))
167
+ return true ;
168
+
147
169
for (Decl *M : ED->getMembers ()) {
148
170
if (doIt (M))
149
171
return true ;
@@ -223,15 +245,13 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
223
245
}
224
246
225
247
bool visitTypeAliasDecl (TypeAliasDecl *TAD) {
226
- if (Walker.shouldWalkIntoGenericParams () && TAD->getGenericParams ()) {
227
- if (visitGenericParamList (TAD->getGenericParams ()))
228
- return true ;
229
- }
248
+ bool WalkGenerics = visitGenericParamListIfNeeded (TAD);
230
249
231
250
if (auto typerepr = TAD->getUnderlyingTypeRepr ())
232
251
if (doIt (typerepr))
233
252
return true ;
234
- return false ;
253
+
254
+ return WalkGenerics && visitTrailingRequirements (TAD);
235
255
}
236
256
237
257
bool visitOpaqueTypeDecl (OpaqueTypeDecl *OTD) {
@@ -269,20 +289,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
269
289
}
270
290
271
291
// Visit requirements
272
- if (WalkGenerics) {
273
- ArrayRef<swift::RequirementRepr> Reqs = None;
274
- if (auto *Protocol = dyn_cast<ProtocolDecl>(NTD)) {
275
- if (auto *WhereClause = Protocol->getTrailingWhereClause ())
276
- Reqs = WhereClause->getRequirements ();
277
- } else {
278
- Reqs = NTD->getGenericParams ()->getTrailingRequirements ();
279
- }
280
- for (auto Req: Reqs) {
281
- if (doIt (Req))
282
- return true ;
283
- }
284
- }
285
-
292
+ if (WalkGenerics && visitTrailingRequirements (NTD))
293
+ return true ;
294
+
286
295
for (Decl *Member : NTD->getMembers ()) {
287
296
if (doIt (Member))
288
297
return true ;
@@ -325,13 +334,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
325
334
if (doIt (SD->getElementTypeLoc ()))
326
335
return true ;
327
336
328
- if (WalkGenerics) {
329
- // Visit generic requirements
330
- for (auto Req : SD->getGenericParams ()->getTrailingRequirements ()) {
331
- if (doIt (Req))
332
- return true ;
333
- }
334
- }
337
+ // Visit trailing requirements
338
+ if (WalkGenerics && visitTrailingRequirements (SD))
339
+ return true ;
335
340
336
341
if (!Walker.shouldWalkAccessorsTheOldWay ()) {
337
342
for (auto *AD : SD->getAllAccessors ())
@@ -364,13 +369,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
364
369
if (doIt (FD->getBodyResultTypeLoc ()))
365
370
return true ;
366
371
367
- if (WalkGenerics) {
368
- // Visit trailing requirments
369
- for (auto Req : AFD->getGenericParams ()->getTrailingRequirements ()) {
370
- if (doIt (Req))
371
- return true ;
372
- }
373
- }
372
+ // Visit trailing requirements
373
+ if (WalkGenerics && visitTrailingRequirements (AFD))
374
+ return true ;
374
375
375
376
if (AFD->getBody (/* canSynthesize=*/ false )) {
376
377
AbstractFunctionDecl::BodyKind PreservedKind = AFD->getBodyKind ();
@@ -1323,17 +1324,6 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
1323
1324
}
1324
1325
return false ;
1325
1326
}
1326
-
1327
- private:
1328
- bool visitGenericParamListIfNeeded (GenericContext *gc) {
1329
- if (Walker.shouldWalkIntoGenericParams ()) {
1330
- if (auto *params = gc->getGenericParams ()) {
1331
- visitGenericParamList (params);
1332
- return true ;
1333
- }
1334
- }
1335
- return false ;
1336
- }
1337
1327
};
1338
1328
1339
1329
} // end anonymous namespace
0 commit comments