@@ -83,8 +83,28 @@ void ConstraintSystem::PotentialBindings::inferTransitiveProtocolRequirements(
8383 for (const auto &entry : bindings.SubtypeOf )
8484 addToWorkList (currentVar, entry.first );
8585
86- for (const auto &entry : bindings.EquivalentTo )
87- addToWorkList (currentVar, entry.first );
86+ // If current type variable is part of an equivalence
87+ // class, make it a "representative" and let's it infer
88+ // supertypes and direct protocol requirements from
89+ // other members.
90+ for (const auto &entry : bindings.EquivalentTo ) {
91+ auto eqBindings = inferredBindings.find (entry.first );
92+ if (eqBindings != inferredBindings.end ()) {
93+ const auto &bindings = eqBindings->getSecond ();
94+
95+ llvm::SmallPtrSet<Constraint *, 2 > placeholder;
96+ // Add any direct protocols from members of the
97+ // equivalence class, so they could be propagated
98+ // to all of the members.
99+ propagateProtocolsTo (currentVar, bindings.Protocols , placeholder);
100+
101+ // Since type variables are equal, current type variable
102+ // becomes a subtype to any supertype found in the current
103+ // equivalence class.
104+ for (const auto &eqEntry : bindings.SubtypeOf )
105+ addToWorkList (currentVar, eqEntry.first );
106+ }
107+ }
88108
89109 // More subtype/equivalences relations have been added.
90110 if (workList.back ().second != currentVar)
@@ -100,9 +120,36 @@ void ConstraintSystem::PotentialBindings::inferTransitiveProtocolRequirements(
100120 propagateProtocolsTo (parent, bindings.Protocols , protocols[currentVar]);
101121 }
102122
123+ auto inferredProtocols = std::move (protocols[currentVar]);
124+
125+ llvm::SmallPtrSet<Constraint *, 4 > protocolsForEquivalence;
126+
127+ // Equivalence class should contain both:
128+ // - direct protocol requirements of the current type
129+ // variable;
130+ // - all of the transitive protocols inferred through
131+ // the members of the equivalence class.
132+ {
133+ protocolsForEquivalence.insert (bindings.Protocols .begin (),
134+ bindings.Protocols .end ());
135+
136+ protocolsForEquivalence.insert (inferredProtocols.begin (),
137+ inferredProtocols.end ());
138+ }
139+
140+ // Propogate inferred protocols to all of the members of the
141+ // equivalence class.
142+ for (const auto &equivalence : bindings.EquivalentTo ) {
143+ auto eqBindings = inferredBindings.find (equivalence.first );
144+ if (eqBindings != inferredBindings.end ()) {
145+ auto &bindings = eqBindings->getSecond ();
146+ bindings.TransitiveProtocols .emplace (protocolsForEquivalence);
147+ }
148+ }
149+
103150 // Update the bindings associated with current type variable,
104151 // to avoid repeating this inference process.
105- bindings.TransitiveProtocols .emplace (std::move (protocols[currentVar] ));
152+ bindings.TransitiveProtocols .emplace (std::move (inferredProtocols ));
106153 } while (!workList.empty ());
107154}
108155
0 commit comments