1717
1818#include " CodeSynthesis.h"
1919#include " TypeChecker.h"
20+ #include " TypeCheckType.h"
21+ #include " llvm/ADT/SmallPtrSet.h"
2022#include " swift/AST/AutoDiff.h"
2123#include " swift/AST/Decl.h"
2224#include " swift/AST/Expr.h"
@@ -627,6 +629,49 @@ deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) {
627629 return propDecl;
628630}
629631
632+ // / Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
633+ // /
634+ // / Precondition: `decl` is a nominal type decl or an extension decl.
635+ void getInheritedProtocols (Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
636+ ArrayRef<TypeLoc> inheritedTypeLocs;
637+ if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
638+ inheritedTypeLocs = nominalDecl->getInherited ();
639+ else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
640+ inheritedTypeLocs = extDecl->getInherited ();
641+ else
642+ llvm_unreachable (" conformance is not a nominal or an extension" );
643+
644+ std::function<void (Type)> handleInheritedType;
645+
646+ auto handleProto = [&](ProtocolType *proto) -> void {
647+ proto->getDecl ()->walkInheritedProtocols ([&](ProtocolDecl *p) -> TypeWalker::Action {
648+ protos.insert (p);
649+ return TypeWalker::Action::Continue;
650+ });
651+ };
652+
653+ auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
654+ for (auto ty : comp->getMembers ())
655+ handleInheritedType (ty);
656+ };
657+
658+ handleInheritedType = [&](Type ty) -> void {
659+ if (auto *proto = ty->getAs <ProtocolType>())
660+ handleProto (proto);
661+ else if (auto *comp = ty->getAs <ProtocolCompositionType>())
662+ handleProtoComp (comp);
663+ };
664+
665+ for (auto loc : inheritedTypeLocs) {
666+ if (loc.getTypeRepr ())
667+ handleInheritedType (TypeResolution::forStructural (
668+ cast<DeclContext>(decl), None, /* unboundTyOpener*/ nullptr )
669+ .resolveType (loc.getTypeRepr ()));
670+ else
671+ handleInheritedType (loc.getType ());
672+ }
673+ }
674+
630675// / Return associated `TangentVector` struct for a nominal type, if it exists.
631676// / If not, synthesize the struct.
632677static StructDecl *
@@ -646,22 +691,43 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
646691 }
647692
648693 // Otherwise, synthesize a new struct.
649- auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
650- auto diffableType = TypeLoc::withoutLoc (diffableProto->getDeclaredInterfaceType ());
651- auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
652- auto addArithType = TypeLoc::withoutLoc (addArithProto->getDeclaredInterfaceType ());
653694
654- // By definition, `TangentVector` must conform to `Differentiable` and
655- // `AdditiveArithmetic`.
656- SmallVector<TypeLoc, 4 > inherited{diffableType, addArithType};
695+ // Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must
696+ // inherit, by collecting all the `TangentVector` conformance requirements imposed by the
697+ // protocols that `derived.ConformanceDecl` inherits.
698+ //
699+ // Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
700+ // the `Differentiable` protocol itself requires that its `TangentVector` conforms to
701+ // `AdditiveArithmetic` and `Differentiable`.
702+ llvm::SmallPtrSet<ProtocolType *, 4 > tvDesiredProtos;
703+ llvm::SmallPtrSet<ProtocolDecl *, 4 > conformanceInheritedProtos;
704+ getInheritedProtocols (derived.ConformanceDecl , conformanceInheritedProtos);
705+ auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
706+ auto *tvAssocType = diffableProto->getAssociatedType (C.Id_TangentVector );
707+ for (auto proto : conformanceInheritedProtos) {
708+ for (auto req : proto->getRequirementSignature ()) {
709+ if (req.getKind () != RequirementKind::Conformance)
710+ continue ;
711+ auto *firstType = req.getFirstType ()->getAs <DependentMemberType>();
712+ if (!firstType || firstType->getAssocType () != tvAssocType)
713+ continue ;
714+ auto tvRequiredProto = req.getSecondType ()->getAs <ProtocolType>();
715+ if (!tvRequiredProto)
716+ continue ;
717+ tvDesiredProtos.insert (tvRequiredProto);
718+ }
719+ }
720+ SmallVector<TypeLoc, 4 > tvDesiredProtoTypeLocs;
721+ for (auto *p : tvDesiredProtos)
722+ tvDesiredProtoTypeLocs.push_back (TypeLoc::withoutLoc (p));
657723
658724 // Cache original members and their associated types for later use.
659725 SmallVector<VarDecl *, 8 > diffProperties;
660726 getStoredPropertiesForDifferentiation (nominal, parentDC, diffProperties);
661727
662728 auto *structDecl =
663729 new (C) StructDecl (SourceLoc (), C.Id_TangentVector , SourceLoc (),
664- /* Inherited*/ C.AllocateCopy (inherited ),
730+ /* Inherited*/ C.AllocateCopy (tvDesiredProtoTypeLocs ),
665731 /* GenericParams*/ {}, parentDC);
666732 structDecl->setImplicit ();
667733 structDecl->copyFormalAccessFrom (nominal, /* sourceIsParentContext*/ true );
0 commit comments