@@ -337,51 +337,6 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
337
337
C.TheEmptyTupleType , {deriveBodyDifferentiable_move, nullptr });
338
338
}
339
339
340
- // / Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
341
- // /
342
- // / Precondition: `decl` is a nominal type decl or an extension decl.
343
- void getInheritedProtocols (Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
344
- ArrayRef<TypeLoc> inheritedTypeLocs;
345
- if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
346
- inheritedTypeLocs = nominalDecl->getInherited ();
347
- else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
348
- inheritedTypeLocs = extDecl->getInherited ();
349
- else
350
- llvm_unreachable (" conformance is not a nominal or an extension" );
351
-
352
- std::function<void (Type)> handleInheritedType;
353
-
354
- auto handleProto = [&](ProtocolType *proto) -> void {
355
- proto->getDecl ()->walkInheritedProtocols ([&](ProtocolDecl *p) -> TypeWalker::Action {
356
- protos.insert (p);
357
- return TypeWalker::Action::Continue;
358
- });
359
- };
360
-
361
- auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
362
- for (auto ty : comp->getMembers ())
363
- handleInheritedType (ty);
364
- };
365
-
366
- handleInheritedType = [&](Type ty) -> void {
367
- if (auto *proto = ty->getAs <ProtocolType>())
368
- handleProto (proto);
369
- else if (auto *comp = ty->getAs <ProtocolCompositionType>())
370
- handleProtoComp (comp);
371
- };
372
-
373
- for (auto loc : inheritedTypeLocs) {
374
- if (loc.getTypeRepr ())
375
- handleInheritedType (
376
- TypeResolution::forStructural (cast<DeclContext>(decl), None,
377
- /* unboundTyOpener*/ nullptr ,
378
- /* placeholderHandler*/ nullptr )
379
- .resolveType (loc.getTypeRepr ()));
380
- else
381
- handleInheritedType (loc.getType ());
382
- }
383
- }
384
-
385
340
// / Return associated `TangentVector` struct for a nominal type, if it exists.
386
341
// / If not, synthesize the struct.
387
342
static StructDecl *
@@ -409,12 +364,14 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
409
364
// Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
410
365
// the `Differentiable` protocol itself requires that its `TangentVector` conforms to
411
366
// `AdditiveArithmetic` and `Differentiable`.
412
- llvm::SmallPtrSet<ProtocolDecl *, 4 > tvDesiredProtos;
413
- llvm::SmallPtrSet<ProtocolDecl *, 4 > conformanceInheritedProtos;
414
- getInheritedProtocols (derived.ConformanceDecl , conformanceInheritedProtos);
367
+ llvm::SmallSetVector<ProtocolDecl *, 4 > tvDesiredProtos;
368
+
415
369
auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
416
370
auto *tvAssocType = diffableProto->getAssociatedType (C.Id_TangentVector );
417
- for (auto proto : conformanceInheritedProtos) {
371
+
372
+ auto localProtos = cast<IterableDeclContext>(derived.ConformanceDecl )
373
+ ->getLocalProtocols ();
374
+ for (auto proto : localProtos) {
418
375
for (auto req : proto->getRequirementSignature ()) {
419
376
if (req.getKind () != RequirementKind::Conformance)
420
377
continue ;
0 commit comments