@@ -227,6 +227,7 @@ void RewriteSystem::initialize(std::vector<std::pair<Term, Term>> &&rules,
227
227
ProtocolGraph &&graph) {
228
228
Protos = graph;
229
229
230
+ // FIXME: Probably this sort is not necessary
230
231
std::sort (rules.begin (), rules.end (),
231
232
[&](std::pair<Term, Term> lhs,
232
233
std::pair<Term, Term> rhs) -> int {
@@ -259,6 +260,14 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
259
260
unsigned i = Rules.size ();
260
261
Rules.emplace_back (lhs, rhs);
261
262
263
+ if (lhs.size () == rhs.size () &&
264
+ std::equal (lhs.begin (), lhs.end () - 1 , rhs.begin ()) &&
265
+ lhs.back ().getKind () == Atom::Kind::AssociatedType &&
266
+ rhs.back ().getKind () == Atom::Kind::AssociatedType &&
267
+ lhs.back ().getName () == rhs.back ().getName ()) {
268
+ MergedAssociatedTypes.emplace_back (lhs, rhs);
269
+ }
270
+
262
271
for (unsigned j : indices (Rules)) {
263
272
if (i == j)
264
273
continue ;
@@ -310,6 +319,110 @@ bool RewriteSystem::simplify(Term &term) const {
310
319
return changed;
311
320
}
312
321
322
+ Atom RewriteSystem::mergeAssociatedTypes (Atom lhs, Atom rhs) const {
323
+ assert (lhs.getKind () == Atom::Kind::AssociatedType);
324
+ assert (rhs.getKind () == Atom::Kind::AssociatedType);
325
+ assert (lhs.getName () == rhs.getName ());
326
+ assert (lhs.compare (rhs, Protos) > 0 );
327
+
328
+ auto protos = lhs.getProtocols ();
329
+ auto otherProtos = rhs.getProtocols ();
330
+
331
+ // Follows from lhs > rhs
332
+ assert (protos.size () <= otherProtos.size ());
333
+
334
+ llvm::TinyPtrVector<const ProtocolDecl *> newProtos;
335
+ std::merge (protos.begin (), protos.end (),
336
+ otherProtos.begin (), otherProtos.end (),
337
+ std::back_inserter (newProtos),
338
+ [&](const ProtocolDecl *lhs,
339
+ const ProtocolDecl *rhs) -> int {
340
+ return Protos.compareProtocols (lhs, rhs) < 0 ;
341
+ });
342
+
343
+ llvm::TinyPtrVector<const ProtocolDecl *> minimalProtos;
344
+ for (const auto *newProto : newProtos) {
345
+ auto inheritsFrom = [&](const ProtocolDecl *thisProto) {
346
+ return (thisProto == newProto ||
347
+ Protos.inheritsFrom (thisProto, newProto));
348
+ };
349
+
350
+ if (std::find_if (protos.begin (), protos.end (), inheritsFrom)
351
+ == protos.end ()) {
352
+ minimalProtos.push_back (newProto);
353
+ }
354
+ }
355
+
356
+ assert (minimalProtos.size () >= protos.size ());
357
+ assert (minimalProtos.size () >= otherProtos.size ());
358
+
359
+ return Atom::forAssociatedType (minimalProtos, lhs.getName ());
360
+ }
361
+
362
+ void RewriteSystem::processMergedAssociatedTypes () {
363
+ if (MergedAssociatedTypes.empty ())
364
+ return ;
365
+
366
+ unsigned i = 0 ;
367
+ while (i < MergedAssociatedTypes.size ()) {
368
+ auto pair = MergedAssociatedTypes[i++];
369
+ const auto &lhs = pair.first ;
370
+ const auto &rhs = pair.second ;
371
+
372
+ // If we have ...[P1:T] => ...[P2:T], add a new pair of rules:
373
+ // ...[P1:T] => ...[P1&P2:T]
374
+ // ...[P2:T] => ...[P1&P2:T]
375
+ if (DebugMerge) {
376
+ llvm::dbgs () << " ## Associated type merge candidate " ;
377
+ lhs.dump (llvm::dbgs ());
378
+ llvm::dbgs () << " => " ;
379
+ rhs.dump (llvm::dbgs ());
380
+ llvm::dbgs () << " \n " ;
381
+ }
382
+
383
+ auto mergedAtom = mergeAssociatedTypes (lhs.back (), rhs.back ());
384
+ if (DebugMerge) {
385
+ llvm::dbgs () << " ### Merged atom " ;
386
+ mergedAtom.dump (llvm::dbgs ());
387
+ llvm::dbgs () << " \n " ;
388
+ }
389
+
390
+ Term mergedTerm = lhs;
391
+ mergedTerm.back () = mergedAtom;
392
+
393
+ addRule (lhs, mergedTerm);
394
+ addRule (rhs, mergedTerm);
395
+
396
+ for (const auto &otherRule : Rules) {
397
+ const auto &otherLHS = otherRule.getLHS ();
398
+ if (otherLHS.size () == 2 &&
399
+ otherLHS[1 ].getKind () == Atom::Kind::Protocol) {
400
+ if (otherLHS[0 ] == lhs.back () ||
401
+ otherLHS[0 ] == rhs.back ()) {
402
+ if (DebugMerge) {
403
+ llvm::dbgs () << " ### Lifting conformance rule " ;
404
+ otherRule.dump (llvm::dbgs ());
405
+ llvm::dbgs () << " \n " ;
406
+ }
407
+
408
+ auto otherRHS = otherRule.getRHS ();
409
+ assert (otherRHS.size () == 1 );
410
+ assert (otherRHS[0 ] == otherLHS[0 ]);
411
+
412
+ otherRHS.back () = mergedAtom;
413
+
414
+ auto newLHS = otherRHS;
415
+ newLHS.add (Atom::forProtocol (otherLHS[1 ].getProtocol ()));
416
+
417
+ addRule (newLHS, otherRHS);
418
+ }
419
+ }
420
+ }
421
+ }
422
+
423
+ MergedAssociatedTypes.clear ();
424
+ }
425
+
313
426
RewriteSystem::CompletionResult
314
427
RewriteSystem::computeConfluentCompletion (unsigned maxIterations,
315
428
unsigned maxDepth) {
@@ -361,8 +474,23 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
361
474
if (rule.canReduceLeftHandSide (newRule))
362
475
rule.markDeleted ();
363
476
}
477
+
478
+ processMergedAssociatedTypes ();
364
479
}
365
480
481
+ // This isn't necessary for correctness, it's just an optimization.
482
+ for (auto &rule : Rules) {
483
+ auto rhs = rule.getRHS ();
484
+ simplify (rhs);
485
+ rule = Rule (rule.getLHS (), rhs);
486
+ }
487
+
488
+ // Just for aesthetics in dump().
489
+ std::sort (Rules.begin (), Rules.end (),
490
+ [&](Rule lhs, Rule rhs) -> int {
491
+ return lhs.getLHS ().compare (rhs.getLHS (), Protos) < 0 ;
492
+ });
493
+
366
494
return CompletionResult::Success;
367
495
}
368
496
0 commit comments