@@ -205,6 +205,42 @@ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(Symbol symbol) {
205
205
return assocType;
206
206
}
207
207
208
+ // / Find the most canonical associated type declaration with the given
209
+ // / name among a set of conforming protocols stored in this property map
210
+ // / entry.
211
+ AssociatedTypeDecl *PropertyBag::getAssociatedType (Identifier name) {
212
+ auto found = AssocTypes.find (name);
213
+ if (found != AssocTypes.end ())
214
+ return found->second ;
215
+
216
+ AssociatedTypeDecl *assocType = nullptr ;
217
+
218
+ for (auto *proto : ConformsTo) {
219
+ auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
220
+ otherAssocType = otherAssocType->getAssociatedTypeAnchor ();
221
+
222
+ if (otherAssocType->getName () == name &&
223
+ (assocType == nullptr ||
224
+ TypeDecl::compare (otherAssocType->getProtocol (),
225
+ assocType->getProtocol ()) < 0 )) {
226
+ assocType = otherAssocType;
227
+ }
228
+ };
229
+
230
+ for (auto *otherAssocType : proto->getAssociatedTypeMembers ()) {
231
+ checkOtherAssocType (otherAssocType);
232
+ }
233
+ }
234
+
235
+ assert (assocType != nullptr && " Need to look harder" );
236
+
237
+ auto inserted = AssocTypes.insert (std::make_pair (name, assocType)).second ;
238
+ assert (inserted);
239
+ (void ) inserted;
240
+
241
+ return assocType;
242
+ }
243
+
208
244
// / Compute the interface type for a range of symbols, with an optional
209
245
// / root type.
210
246
// /
@@ -233,8 +269,8 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
233
269
result = genericParam;
234
270
};
235
271
236
- for (; begin != end; ++begin ) {
237
- auto symbol = *begin ;
272
+ for (auto *iter = begin; iter != end; ++iter ) {
273
+ auto symbol = *iter ;
238
274
239
275
if (!result) {
240
276
// A valid term always begins with a generic parameter, protocol or
@@ -253,7 +289,7 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
253
289
handleRoot (GenericTypeParamType::get (/* type sequence*/ false , 0 , 0 ,
254
290
ctx.getASTContext ()));
255
291
256
- // An associated type term at the root means we have a dependent
292
+ // An associated type symbol at the root means we have a dependent
257
293
// member type rooted at Self; handle the associated type below.
258
294
break ;
259
295
@@ -281,17 +317,48 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
281
317
if (symbol.getKind () == Symbol::Kind::Protocol) {
282
318
#ifndef NDEBUG
283
319
// Ensure that the domain of the suffix contains P.
284
- if (begin + 1 < end) {
285
- auto protos = (begin + 1 )->getProtocols ();
320
+ if (iter + 1 < end) {
321
+ auto protos = (iter + 1 )->getProtocols ();
286
322
assert (std::find (protos.begin (), protos.end (), symbol.getProtocol ()));
287
323
}
288
324
#endif
289
325
continue ;
290
326
}
291
327
292
328
// We should have a resolved type at this point.
293
- auto *assocType =
294
- ctx.getAssociatedTypeForSymbol (symbol);
329
+ AssociatedTypeDecl *assocType;
330
+
331
+ if (begin == iter) {
332
+ // FIXME: Eliminate this case once merged associated types are gone.
333
+ assocType = ctx.getAssociatedTypeForSymbol (symbol);
334
+ } else {
335
+ // The protocol stored in an associated type symbol appearing in a
336
+ // canonical term is not necessarily the right protocol to look for
337
+ // an associated type declaration to get a canonical _type_, because
338
+ // the reduction order on terms is different than the canonical order
339
+ // on types.
340
+ //
341
+ // Instead, find all protocols that the prefix conforms to, and look
342
+ // for an associated type in those protocols.
343
+ MutableTerm prefix (begin, iter);
344
+ assert (prefix.size () > 0 );
345
+
346
+ auto *props = map.lookUpProperties (prefix.rbegin (), prefix.rend ());
347
+ assert (props != nullptr );
348
+
349
+ // Assert that the associated type's protocol appears among the set
350
+ // of protocols that the prefix conforms to.
351
+ #ifndef NDEBUG
352
+ auto conformsTo = props->getConformsTo ();
353
+ for (auto *otherProto : symbol.getProtocols ()) {
354
+ assert (std::find (conformsTo.begin (), conformsTo.end (), otherProto)
355
+ != conformsTo.end ());
356
+ }
357
+ #endif
358
+
359
+ assocType = props->getAssociatedType (symbol.getName ());
360
+ }
361
+
295
362
result = DependentMemberType::get (result, assocType);
296
363
}
297
364
0 commit comments