Skip to content

Commit 9e9de1d

Browse files
author
Marc Rasi
committed
inherit required protocols during TangentVector synthesis
1 parent 6722d71 commit 9e9de1d

File tree

2 files changed

+98
-14
lines changed

2 files changed

+98
-14
lines changed

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
#include "CodeSynthesis.h"
1919
#include "TypeChecker.h"
20+
#include "TypeCheckType.h"
21+
#include "llvm/ADT/SmallPtrSet.h"
2022
#include "swift/AST/AutoDiff.h"
2123
#include "swift/AST/Decl.h"
2224
#include "swift/AST/Expr.h"
@@ -627,6 +629,49 @@ deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) {
627629
return propDecl;
628630
}
629631

632+
/// Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
633+
///
634+
/// Precondition: `decl` is a nominal type decl or an extension decl.
635+
void getInheritedProtocols(Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
636+
ArrayRef<TypeLoc> inheritedTypeLocs;
637+
if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
638+
inheritedTypeLocs = nominalDecl->getInherited();
639+
else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
640+
inheritedTypeLocs = extDecl->getInherited();
641+
else
642+
llvm_unreachable("conformance is not a nominal or an extension");
643+
644+
std::function<void(Type)> handleInheritedType;
645+
646+
auto handleProto = [&](ProtocolType *proto) -> void {
647+
proto->getDecl()->walkInheritedProtocols([&](ProtocolDecl *p) -> TypeWalker::Action {
648+
protos.insert(p);
649+
return TypeWalker::Action::Continue;
650+
});
651+
};
652+
653+
auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
654+
for (auto ty : comp->getMembers())
655+
handleInheritedType(ty);
656+
};
657+
658+
handleInheritedType = [&](Type ty) -> void {
659+
if (auto *proto = ty->getAs<ProtocolType>())
660+
handleProto(proto);
661+
else if (auto *comp = ty->getAs<ProtocolCompositionType>())
662+
handleProtoComp(comp);
663+
};
664+
665+
for (auto loc : inheritedTypeLocs) {
666+
if (loc.getTypeRepr())
667+
handleInheritedType(TypeResolution::forStructural(
668+
cast<DeclContext>(decl), None, /*unboundTyOpener*/ nullptr)
669+
.resolveType(loc.getTypeRepr()));
670+
else
671+
handleInheritedType(loc.getType());
672+
}
673+
}
674+
630675
/// Return associated `TangentVector` struct for a nominal type, if it exists.
631676
/// If not, synthesize the struct.
632677
static StructDecl *
@@ -646,22 +691,43 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
646691
}
647692

648693
// Otherwise, synthesize a new struct.
649-
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
650-
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredInterfaceType());
651-
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
652-
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredInterfaceType());
653694

654-
// By definition, `TangentVector` must conform to `Differentiable` and
655-
// `AdditiveArithmetic`.
656-
SmallVector<TypeLoc, 4> inherited{diffableType, addArithType};
695+
// Compute `tvDesiredProtos`, the set of protocols that the new `TangentVector` struct must
696+
// inherit, by collecting all the `TangentVector` conformance requirements imposed by the
697+
// protocols that `derived.ConformanceDecl` inherits.
698+
//
699+
// Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
700+
// the `Differentiable` protocol itself requires that its `TangentVector` conforms to
701+
// `AdditiveArithmetic` and `Differentiable`.
702+
llvm::SmallPtrSet<ProtocolType *, 4> tvDesiredProtos;
703+
llvm::SmallPtrSet<ProtocolDecl *, 4> conformanceInheritedProtos;
704+
getInheritedProtocols(derived.ConformanceDecl, conformanceInheritedProtos);
705+
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
706+
auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector);
707+
for (auto proto : conformanceInheritedProtos) {
708+
for (auto req : proto->getRequirementSignature()) {
709+
if (req.getKind() != RequirementKind::Conformance)
710+
continue;
711+
auto *firstType = req.getFirstType()->getAs<DependentMemberType>();
712+
if (!firstType || firstType->getAssocType() != tvAssocType)
713+
continue;
714+
auto tvRequiredProto = req.getSecondType()->getAs<ProtocolType>();
715+
if (!tvRequiredProto)
716+
continue;
717+
tvDesiredProtos.insert(tvRequiredProto);
718+
}
719+
}
720+
SmallVector<TypeLoc, 4> tvDesiredProtoTypeLocs;
721+
for (auto *p : tvDesiredProtos)
722+
tvDesiredProtoTypeLocs.push_back(TypeLoc::withoutLoc(p));
657723

658724
// Cache original members and their associated types for later use.
659725
SmallVector<VarDecl *, 8> diffProperties;
660726
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
661727

662728
auto *structDecl =
663729
new (C) StructDecl(SourceLoc(), C.Id_TangentVector, SourceLoc(),
664-
/*Inherited*/ C.AllocateCopy(inherited),
730+
/*Inherited*/ C.AllocateCopy(tvDesiredProtoTypeLocs),
665731
/*GenericParams*/ {}, parentDC);
666732
structDecl->setImplicit();
667733
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);

test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct GenericTangentVectorMember<T: Differentiable>: Differentiable,
88
var x: T.TangentVector
99
}
1010

11-
// CHECK-AST-LABEL: internal struct GenericTangentVectorMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
11+
// CHECK-AST-LABEL: internal struct GenericTangentVectorMember<T> : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} where T : Differentiable
1212
// CHECK-AST: internal var x: T.TangentVector
1313
// CHECK-AST: internal init(x: T.TangentVector)
1414
// CHECK-AST: internal typealias TangentVector = GenericTangentVectorMember<T>
@@ -62,15 +62,15 @@ final class AdditiveArithmeticClass<T: AdditiveArithmetic & Differentiable>: Add
6262

6363
// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
6464
// CHECK-AST: final internal var x: T, y: T
65-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
65+
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}}
6666
// CHECK-AST: }
6767

6868
@frozen
6969
public struct FrozenStruct: Differentiable {}
7070

7171
// CHECK-AST-LABEL: @frozen public struct FrozenStruct : Differentiable {
7272
// CHECK-AST: internal init()
73-
// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic {
73+
// CHECK-AST: @frozen public struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
7474

7575
@usableFromInline
7676
struct UsableFromInlineStruct: Differentiable {}
@@ -79,7 +79,7 @@ struct UsableFromInlineStruct: Differentiable {}
7979
// CHECK-AST: struct UsableFromInlineStruct : Differentiable {
8080
// CHECK-AST: internal init()
8181
// CHECK-AST: @usableFromInline
82-
// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
82+
// CHECK-AST: struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
8383

8484
// Test property wrappers.
8585

@@ -96,7 +96,7 @@ struct WrappedPropertiesStruct: Differentiable {
9696
}
9797

9898
// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
99-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
99+
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
100100
// CHECK-AST: internal var x: Float.TangentVector
101101
// CHECK-AST: internal var y: Float.TangentVector
102102
// CHECK-AST: internal var z: Float.TangentVector
@@ -111,9 +111,27 @@ class WrappedPropertiesClass: Differentiable {
111111
}
112112

113113
// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
114-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
114+
// CHECK-AST: internal struct TangentVector : {{(Differentiable, AdditiveArithmetic)|(AdditiveArithmetic, Differentiable)}} {
115115
// CHECK-AST: internal var x: Float.TangentVector
116116
// CHECK-AST: internal var y: Float.TangentVector
117117
// CHECK-AST: internal var z: Float.TangentVector
118118
// CHECK-AST: }
119119
// CHECK-AST: }
120+
121+
protocol TangentVectorMustBeEncodable: Differentiable where TangentVector: Encodable {}
122+
123+
struct AutoDeriveEncodableTV1: TangentVectorMustBeEncodable {
124+
var x: Float
125+
}
126+
127+
// CHECK-AST-LABEL: internal struct AutoDeriveEncodableTV1 : TangentVectorMustBeEncodable {
128+
// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} {
129+
130+
struct AutoDeriveEncodableTV2 {
131+
var x: Float
132+
}
133+
134+
extension AutoDeriveEncodableTV2: TangentVectorMustBeEncodable {}
135+
136+
// CHECK-AST-LABEL: extension AutoDeriveEncodableTV2 : TangentVectorMustBeEncodable {
137+
// CHECK-AST: internal struct TangentVector : {{(Encodable, Differentiable, AdditiveArithmetic)|(Encodable, AdditiveArithmetic, Differentiable)|(Differentiable, Encodable, AdditiveArithmetic)|(AdditiveArithmetic, Encodable, Differentiable)|(Differentiable, AdditiveArithmetic, Encodable)|(AdditiveArithmetic, Differentiable, Encodable)}} {

0 commit comments

Comments
 (0)