@@ -227,6 +227,7 @@ void RewriteSystem::initialize(std::vector<std::pair<Term, Term>> &&rules,
227227 ProtocolGraph &&graph) {
228228 Protos = graph;
229229
230+ // FIXME: Probably this sort is not necessary
230231 std::sort (rules.begin (), rules.end (),
231232 [&](std::pair<Term, Term> lhs,
232233 std::pair<Term, Term> rhs) -> int {
@@ -259,6 +260,14 @@ bool RewriteSystem::addRule(Term lhs, Term rhs) {
259260 unsigned i = Rules.size ();
260261 Rules.emplace_back (lhs, rhs);
261262
263+ if (lhs.size () == rhs.size () &&
264+ std::equal (lhs.begin (), lhs.end () - 1 , rhs.begin ()) &&
265+ lhs.back ().getKind () == Atom::Kind::AssociatedType &&
266+ rhs.back ().getKind () == Atom::Kind::AssociatedType &&
267+ lhs.back ().getName () == rhs.back ().getName ()) {
268+ MergedAssociatedTypes.emplace_back (lhs, rhs);
269+ }
270+
262271 for (unsigned j : indices (Rules)) {
263272 if (i == j)
264273 continue ;
@@ -310,6 +319,110 @@ bool RewriteSystem::simplify(Term &term) const {
310319 return changed;
311320}
312321
322+ Atom RewriteSystem::mergeAssociatedTypes (Atom lhs, Atom rhs) const {
323+ assert (lhs.getKind () == Atom::Kind::AssociatedType);
324+ assert (rhs.getKind () == Atom::Kind::AssociatedType);
325+ assert (lhs.getName () == rhs.getName ());
326+ assert (lhs.compare (rhs, Protos) > 0 );
327+
328+ auto protos = lhs.getProtocols ();
329+ auto otherProtos = rhs.getProtocols ();
330+
331+ // Follows from lhs > rhs
332+ assert (protos.size () <= otherProtos.size ());
333+
334+ llvm::TinyPtrVector<const ProtocolDecl *> newProtos;
335+ std::merge (protos.begin (), protos.end (),
336+ otherProtos.begin (), otherProtos.end (),
337+ std::back_inserter (newProtos),
338+ [&](const ProtocolDecl *lhs,
339+ const ProtocolDecl *rhs) -> int {
340+ return Protos.compareProtocols (lhs, rhs) < 0 ;
341+ });
342+
343+ llvm::TinyPtrVector<const ProtocolDecl *> minimalProtos;
344+ for (const auto *newProto : newProtos) {
345+ auto inheritsFrom = [&](const ProtocolDecl *thisProto) {
346+ return (thisProto == newProto ||
347+ Protos.inheritsFrom (thisProto, newProto));
348+ };
349+
350+ if (std::find_if (protos.begin (), protos.end (), inheritsFrom)
351+ == protos.end ()) {
352+ minimalProtos.push_back (newProto);
353+ }
354+ }
355+
356+ assert (minimalProtos.size () >= protos.size ());
357+ assert (minimalProtos.size () >= otherProtos.size ());
358+
359+ return Atom::forAssociatedType (minimalProtos, lhs.getName ());
360+ }
361+
362+ void RewriteSystem::processMergedAssociatedTypes () {
363+ if (MergedAssociatedTypes.empty ())
364+ return ;
365+
366+ unsigned i = 0 ;
367+ while (i < MergedAssociatedTypes.size ()) {
368+ auto pair = MergedAssociatedTypes[i++];
369+ const auto &lhs = pair.first ;
370+ const auto &rhs = pair.second ;
371+
372+ // If we have ...[P1:T] => ...[P2:T], add a new pair of rules:
373+ // ...[P1:T] => ...[P1&P2:T]
374+ // ...[P2:T] => ...[P1&P2:T]
375+ if (DebugMerge) {
376+ llvm::dbgs () << " ## Associated type merge candidate " ;
377+ lhs.dump (llvm::dbgs ());
378+ llvm::dbgs () << " => " ;
379+ rhs.dump (llvm::dbgs ());
380+ llvm::dbgs () << " \n " ;
381+ }
382+
383+ auto mergedAtom = mergeAssociatedTypes (lhs.back (), rhs.back ());
384+ if (DebugMerge) {
385+ llvm::dbgs () << " ### Merged atom " ;
386+ mergedAtom.dump (llvm::dbgs ());
387+ llvm::dbgs () << " \n " ;
388+ }
389+
390+ Term mergedTerm = lhs;
391+ mergedTerm.back () = mergedAtom;
392+
393+ addRule (lhs, mergedTerm);
394+ addRule (rhs, mergedTerm);
395+
396+ for (const auto &otherRule : Rules) {
397+ const auto &otherLHS = otherRule.getLHS ();
398+ if (otherLHS.size () == 2 &&
399+ otherLHS[1 ].getKind () == Atom::Kind::Protocol) {
400+ if (otherLHS[0 ] == lhs.back () ||
401+ otherLHS[0 ] == rhs.back ()) {
402+ if (DebugMerge) {
403+ llvm::dbgs () << " ### Lifting conformance rule " ;
404+ otherRule.dump (llvm::dbgs ());
405+ llvm::dbgs () << " \n " ;
406+ }
407+
408+ auto otherRHS = otherRule.getRHS ();
409+ assert (otherRHS.size () == 1 );
410+ assert (otherRHS[0 ] == otherLHS[0 ]);
411+
412+ otherRHS.back () = mergedAtom;
413+
414+ auto newLHS = otherRHS;
415+ newLHS.add (Atom::forProtocol (otherLHS[1 ].getProtocol ()));
416+
417+ addRule (newLHS, otherRHS);
418+ }
419+ }
420+ }
421+ }
422+
423+ MergedAssociatedTypes.clear ();
424+ }
425+
313426RewriteSystem::CompletionResult
314427RewriteSystem::computeConfluentCompletion (unsigned maxIterations,
315428 unsigned maxDepth) {
@@ -361,8 +474,23 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
361474 if (rule.canReduceLeftHandSide (newRule))
362475 rule.markDeleted ();
363476 }
477+
478+ processMergedAssociatedTypes ();
364479 }
365480
481+ // This isn't necessary for correctness, it's just an optimization.
482+ for (auto &rule : Rules) {
483+ auto rhs = rule.getRHS ();
484+ simplify (rhs);
485+ rule = Rule (rule.getLHS (), rhs);
486+ }
487+
488+ // Just for aesthetics in dump().
489+ std::sort (Rules.begin (), Rules.end (),
490+ [&](Rule lhs, Rule rhs) -> int {
491+ return lhs.getLHS ().compare (rhs.getLHS (), Protos) < 0 ;
492+ });
493+
366494 return CompletionResult::Success;
367495}
368496
0 commit comments