@@ -239,6 +239,12 @@ class RewriteTreeNode {
239
239
rewrite = replacementPath;
240
240
}
241
241
242
+ // / Remove the rewrite rule.
243
+ void removeRewriteRule () {
244
+ assert (hasRewriteRule ());
245
+ assocTypeAndHasRewrite.setInt (false );
246
+ }
247
+
242
248
// / Retrieve the path to which this node will be rewritten.
243
249
const RewritePath &getRewriteRule () const {
244
250
assert (hasRewriteRule ());
@@ -289,16 +295,67 @@ class RewriteTreeNode {
289
295
// / Merge the given rewrite tree into \c other.
290
296
void mergeInto (RewriteTreeNode *other);
291
297
298
+ // / An action to perform for the given rule
299
+ class RuleAction {
300
+ enum Kind {
301
+ // / No action; continue traversal.
302
+ None,
303
+
304
+ // / Stop traversal.
305
+ Stop,
306
+
307
+ // / Remove the given rule completely.
308
+ Remove,
309
+
310
+ // / Replace the right-hand side of the rule with the given new path.
311
+ Replace,
312
+ } kind;
313
+
314
+ RewritePath path;
315
+
316
+ RuleAction (Kind kind, RewritePath path = {})
317
+ : kind(kind), path(path) { }
318
+
319
+ friend class RewriteTreeNode ;
320
+
321
+ public:
322
+ static RuleAction none () { return RuleAction (None); }
323
+ static RuleAction stop () { return RuleAction (Stop); }
324
+ static RuleAction remove () { return RuleAction (Remove); }
325
+
326
+ static RuleAction replace (RewritePath path) {
327
+ return RuleAction (Replace, std::move (path));
328
+ }
329
+
330
+ operator Kind () const { return kind; }
331
+ };
332
+
333
+ // / Callback function for enumerating rules in a tree.
334
+ using EnumerateCallback =
335
+ RuleAction (RelativeRewritePath lhs, const RewritePath &rhs);
336
+
337
+ // / Enumerate all of the rewrite rules, calling \c fn with the left and
338
+ // / right-hand sides of each rule.
339
+ // /
340
+ // / \returns true if the action function returned \c Stop at any point.
341
+ bool enumerateRules (llvm::function_ref<EnumerateCallback> fn) {
342
+ SmallVector<AssociatedTypeDecl *, 4 > lhs;
343
+ return enumerateRulesRec (fn, lhs);
344
+ }
345
+
292
346
LLVM_ATTRIBUTE_DEPRECATED (void dump () const LLVM_ATTRIBUTE_USED,
293
347
"only for use within the debugger");
294
348
295
349
// / Dump the tree.
296
350
void dump (llvm::raw_ostream &out, bool lastChild = true ) const ;
297
351
298
352
private:
299
- // / Merge the given rewrite tree into \c other.
300
- void mergeIntoRec (RewriteTreeNode *other,
301
- llvm::SmallVectorImpl<AssociatedTypeDecl *> &matchPath);
353
+ // / Enumerate all of the rewrite rules, calling \c fn with the left and
354
+ // / right-hand sides of each rule.
355
+ // /
356
+ // / \returns true if the action function returned \c Stop at any point.
357
+ bool enumerateRulesRec (llvm::function_ref<EnumerateCallback> &fn,
358
+ llvm::SmallVectorImpl<AssociatedTypeDecl *> &lhs);
302
359
};
303
360
}
304
361
@@ -3268,28 +3325,53 @@ RewriteTreeNode::bestMatch(GenericParamKey base, RelativeRewritePath path,
3268
3325
}
3269
3326
3270
3327
void RewriteTreeNode::mergeInto (RewriteTreeNode *other) {
3271
- SmallVector<AssociatedTypeDecl *, 4 > matchPath;
3272
- mergeIntoRec (other, matchPath);
3273
- }
3274
-
3275
- void RewriteTreeNode::mergeIntoRec (
3276
- RewriteTreeNode *other,
3277
- llvm::SmallVectorImpl<AssociatedTypeDecl *> &matchPath) {
3278
3328
// FIXME: A destructive version of this operation would be more efficient,
3279
3329
// since we generally don't care about \c other after doing this.
3330
+ (void )enumerateRules ([other](RelativeRewritePath lhs,
3331
+ const RewritePath &rhs) {
3332
+ other->addRewriteRule (lhs, rhs);
3333
+ return RuleAction::none ();
3334
+ });
3335
+ }
3336
+
3337
+ bool RewriteTreeNode::enumerateRulesRec (
3338
+ llvm::function_ref<EnumerateCallback> &fn,
3339
+ llvm::SmallVectorImpl<AssociatedTypeDecl *> &lhs) {
3280
3340
if (auto assocType = getMatch ())
3281
- matchPath.push_back (assocType);
3341
+ lhs.push_back (assocType);
3342
+
3343
+ SWIFT_DEFER {
3344
+ if (auto assocType = getMatch ())
3345
+ lhs.pop_back ();
3346
+ };
3347
+
3348
+ // If there is a rewrite rule, invoke the callback.
3349
+ if (hasRewriteRule ()) {
3350
+ switch (RuleAction action = fn (lhs, getRewriteRule ())) {
3351
+ case RuleAction::None:
3352
+ break ;
3282
3353
3283
- // Add this rewrite rule, if there is one.
3284
- if (hasRewriteRule ())
3285
- other->addRewriteRule (matchPath, rewrite);
3354
+ case RuleAction::Stop:
3355
+ return true ;
3356
+
3357
+ case RuleAction::Remove:
3358
+ removeRewriteRule ();
3359
+ break ;
3360
+
3361
+ case RuleAction::Replace:
3362
+ removeRewriteRule ();
3363
+ setRewriteRule (action.path );
3364
+ break ;
3365
+ }
3366
+ }
3286
3367
3287
3368
// Recurse into the child nodes.
3288
- for (auto child : children)
3289
- child->mergeIntoRec (other, matchPath);
3369
+ for (auto child : children) {
3370
+ if (child->enumerateRulesRec (fn, lhs))
3371
+ return true ;
3372
+ }
3290
3373
3291
- if (getMatch ())
3292
- matchPath.pop_back ();
3374
+ return false ;
3293
3375
}
3294
3376
3295
3377
void RewriteTreeNode::dump () const {
0 commit comments