17
17
18
18
#include " CodeSynthesis.h"
19
19
#include " TypeChecker.h"
20
+ #include " TypeCheckType.h"
21
+ #include " llvm/ADT/SmallPtrSet.h"
20
22
#include " swift/AST/AutoDiff.h"
21
23
#include " swift/AST/Decl.h"
22
24
#include " swift/AST/Expr.h"
@@ -627,6 +629,49 @@ deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) {
627
629
return propDecl;
628
630
}
629
631
632
+ // / Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
633
+ // /
634
+ // / Precondition: `decl` is a nominal type decl or an extension decl.
635
+ void getInheritedProtocols (Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
636
+ ArrayRef<TypeLoc> inheritedTypeLocs;
637
+ if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
638
+ inheritedTypeLocs = nominalDecl->getInherited ();
639
+ else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
640
+ inheritedTypeLocs = extDecl->getInherited ();
641
+ else
642
+ llvm_unreachable (" conformance is not a nominal or an extension" );
643
+
644
+ std::function<void (Type)> handleInheritedType;
645
+
646
+ auto handleProto = [&](ProtocolType *proto) -> void {
647
+ proto->getDecl ()->walkInheritedProtocols ([&](ProtocolDecl *p) -> TypeWalker::Action {
648
+ protos.insert (p);
649
+ return TypeWalker::Action::Continue;
650
+ });
651
+ };
652
+
653
+ auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
654
+ for (auto ty : comp->getMembers ())
655
+ handleInheritedType (ty);
656
+ };
657
+
658
+ handleInheritedType = [&](Type ty) -> void {
659
+ if (auto *proto = ty->getAs <ProtocolType>())
660
+ handleProto (proto);
661
+ else if (auto *comp = ty->getAs <ProtocolCompositionType>())
662
+ handleProtoComp (comp);
663
+ };
664
+
665
+ for (auto loc : inheritedTypeLocs) {
666
+ if (loc.getTypeRepr ())
667
+ handleInheritedType (TypeResolution::forStructural (
668
+ cast<DeclContext>(decl), None, /* unboundTyOpener*/ nullptr )
669
+ .resolveType (loc.getTypeRepr ()));
670
+ else
671
+ handleInheritedType (loc.getType ());
672
+ }
673
+ }
674
+
630
675
// / Return associated `TangentVector` struct for a nominal type, if it exists.
631
676
// / If not, synthesize the struct.
632
677
static StructDecl *
@@ -646,22 +691,43 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
646
691
}
647
692
648
693
// Otherwise, synthesize a new struct.
649
- auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
650
- auto diffableType = TypeLoc::withoutLoc (diffableProto->getDeclaredInterfaceType ());
651
- auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
652
- auto addArithType = TypeLoc::withoutLoc (addArithProto->getDeclaredInterfaceType ());
653
694
654
- // By definition, `TangentVector` must conform to `Differentiable` and
655
- // `AdditiveArithmetic`.
656
- SmallVector<TypeLoc, 4 > inherited{diffableType, addArithType};
695
+ // Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must
696
+ // inherit, by collecting all the `TangentVector` conformance requirements imposed by the
697
+ // protocols that `derived.ConformanceDecl` inherits.
698
+ //
699
+ // Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
700
+ // the `Differentiable` protocol itself requires that its `TangentVector` conforms to
701
+ // `AdditiveArithmetic` and `Differentiable`.
702
+ llvm::SmallPtrSet<ProtocolType *, 4 > tvDesiredProtos;
703
+ llvm::SmallPtrSet<ProtocolDecl *, 4 > conformanceInheritedProtos;
704
+ getInheritedProtocols (derived.ConformanceDecl , conformanceInheritedProtos);
705
+ auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
706
+ auto *tvAssocType = diffableProto->getAssociatedType (C.Id_TangentVector );
707
+ for (auto proto : conformanceInheritedProtos) {
708
+ for (auto req : proto->getRequirementSignature ()) {
709
+ if (req.getKind () != RequirementKind::Conformance)
710
+ continue ;
711
+ auto *firstType = req.getFirstType ()->getAs <DependentMemberType>();
712
+ if (!firstType || firstType->getAssocType () != tvAssocType)
713
+ continue ;
714
+ auto tvRequiredProto = req.getSecondType ()->getAs <ProtocolType>();
715
+ if (!tvRequiredProto)
716
+ continue ;
717
+ tvDesiredProtos.insert (tvRequiredProto);
718
+ }
719
+ }
720
+ SmallVector<TypeLoc, 4 > tvDesiredProtoTypeLocs;
721
+ for (auto *p : tvDesiredProtos)
722
+ tvDesiredProtoTypeLocs.push_back (TypeLoc::withoutLoc (p));
657
723
658
724
// Cache original members and their associated types for later use.
659
725
SmallVector<VarDecl *, 8 > diffProperties;
660
726
getStoredPropertiesForDifferentiation (nominal, parentDC, diffProperties);
661
727
662
728
auto *structDecl =
663
729
new (C) StructDecl (SourceLoc (), C.Id_TangentVector , SourceLoc (),
664
- /* Inherited*/ C.AllocateCopy (inherited ),
730
+ /* Inherited*/ C.AllocateCopy (tvDesiredProtoTypeLocs ),
665
731
/* GenericParams*/ {}, parentDC);
666
732
structDecl->setImplicit ();
667
733
structDecl->copyFormalAccessFrom (nominal, /* sourceIsParentContext*/ true );
0 commit comments