@@ -372,7 +372,64 @@ SyntacticElementTarget::walk(ASTWalker &walker) const {
372
372
break ;
373
373
}
374
374
case Kind::forEachStmt: {
375
- if (auto *newStmt = getAsForEachStmt ()->walk (walker)) {
375
+ // We need to skip the where clause if requested, and we currently do not
376
+ // type-check a for loop's BraceStmt as part of the SyntacticElementTarget,
377
+ // so we need to skip it here.
378
+ // TODO: We ought to be able to fold BraceStmt checking into the constraint
379
+ // system eventually.
380
+ class ForEachWalker : public ASTWalker {
381
+ ASTWalker &Walker;
382
+ SyntacticElementTarget Target;
383
+ ForEachStmt *ForStmt;
384
+
385
+ public:
386
+ ForEachWalker (ASTWalker &walker, SyntacticElementTarget target)
387
+ : Walker(walker), Target(target), ForStmt(target.getAsForEachStmt()) {}
388
+
389
+ PreWalkAction walkToDeclPre (Decl *D) override {
390
+ if (D->walk (Walker))
391
+ return Action::Stop ();
392
+ return Action::SkipNode ();
393
+ }
394
+
395
+ PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
396
+ // Ignore where clause if needed.
397
+ if (Target.ignoreForEachWhereClause () && E == ForStmt->getWhere ())
398
+ return Action::SkipNode (E);
399
+
400
+ E = E->walk (Walker);
401
+
402
+ if (!E)
403
+ return Action::Stop ();
404
+ return Action::SkipNode (E);
405
+ }
406
+
407
+ PreWalkResult<Stmt *> walkToStmtPre (Stmt *S) override {
408
+ // We only want to visit the children of the ForEachStmt.
409
+ if (S == ForStmt)
410
+ return Action::Continue (S);
411
+
412
+ // But not its body.
413
+ if (S != ForStmt->getBody ())
414
+ S = S->walk (Walker);
415
+
416
+ if (!S)
417
+ return Action::Stop ();
418
+
419
+ return Action::SkipNode (S);
420
+ }
421
+
422
+ PreWalkResult<Pattern *> walkToPatternPre (Pattern *P) override {
423
+ P = P->walk (Walker);
424
+ if (!P)
425
+ return Action::Stop ();
426
+ return Action::SkipNode (P);
427
+ }
428
+ };
429
+
430
+ ForEachWalker forEachWalker (walker, *this );
431
+
432
+ if (auto *newStmt = getAsForEachStmt ()->walk (forEachWalker)) {
376
433
result.forEachStmt .stmt = cast<ForEachStmt>(newStmt);
377
434
} else {
378
435
return std::nullopt;
0 commit comments